From ed8c9f012493a861f5993c55f2d8206d6368e944 Mon Sep 17 00:00:00 2001 From: Eric Date: Wed, 8 Nov 2023 21:24:35 +0000 Subject: [PATCH] Revert "Add Statement (#2294) (#2318) (#2319)" This reverts commit b3c2e94cac59bf324440accddbbc7941fa7c0c07. Signed-off-by: Eric --- spark/build.gradle | 5 +- .../execution/session/InteractiveSession.java | 73 +--- .../sql/spark/execution/session/Session.java | 21 -- .../execution/session/SessionManager.java | 10 +- .../spark/execution/session/SessionModel.java | 30 +- .../spark/execution/session/SessionState.java | 4 - .../execution/statement/QueryRequest.java | 15 - .../spark/execution/statement/Statement.java | 85 ----- .../execution/statement/StatementId.java | 23 -- .../execution/statement/StatementModel.java | 194 ---------- .../execution/statement/StatementState.java | 38 -- .../statestore/SessionStateStore.java | 87 +++++ .../execution/statestore/StateModel.java | 30 -- .../execution/statestore/StateStore.java | 149 -------- .../session/InteractiveSessionTest.java | 27 +- .../execution/session/SessionManagerTest.java | 15 +- .../statement/StatementStateTest.java | 20 - .../execution/statement/StatementTest.java | 356 ------------------ .../statestore/SessionStateStoreTest.java | 42 +++ 19 files changed, 169 insertions(+), 1055 deletions(-) delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java delete mode 100644 spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java delete mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java delete mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java diff --git a/spark/build.gradle b/spark/build.gradle index 15f1e200e0..49ff96bec5 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -119,9 +119,8 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.dispatcher.model.*', 'org.opensearch.sql.spark.flint.FlintIndexType', // ignore because XContext IOException - 'org.opensearch.sql.spark.execution.statestore.StateStore', - 'org.opensearch.sql.spark.execution.session.SessionModel', - 'org.opensearch.sql.spark.execution.statement.StatementModel' + 'org.opensearch.sql.spark.execution.statestore.SessionStateStore', + 'org.opensearch.sql.spark.execution.session.SessionModel' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index e33ef4245a..620e46b9be 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -6,10 +6,6 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; -import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; -import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import java.util.Optional; import lombok.Builder; @@ -18,11 +14,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statement.QueryRequest; -import org.opensearch.sql.spark.execution.statement.Statement; -import org.opensearch.sql.spark.execution.statement.StatementId; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; /** * Interactive session. @@ -35,8 +27,9 @@ public class InteractiveSession implements Session { private static final Logger LOG = LogManager.getLogger(); private final SessionId sessionId; - private final StateStore stateStore; + private final SessionStateStore sessionStateStore; private final EMRServerlessClient serverlessClient; + private SessionModel sessionModel; @Override @@ -48,7 +41,7 @@ public void open(CreateSessionRequest createSessionRequest) { sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - createSession(stateStore).apply(sessionModel); + sessionStateStore.create(sessionModel); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -56,67 +49,13 @@ public void open(CreateSessionRequest createSessionRequest) { } } - /** todo. StatementSweeper will delete doc. */ @Override public void close() { - Optional model = getSession(stateStore).apply(sessionModel.getId()); + Optional model = sessionStateStore.get(sessionModel.getSessionId()); if (model.isEmpty()) { - throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); + throw new IllegalStateException("session not exist. " + sessionModel.getSessionId()); } else { serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId()); } } - - /** Submit statement. If submit successfully, Statement in waiting state. */ - public StatementId submit(QueryRequest request) { - Optional model = getSession(stateStore).apply(sessionModel.getId()); - if (model.isEmpty()) { - throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); - } else { - sessionModel = model.get(); - if (!END_STATE.contains(sessionModel.getSessionState())) { - StatementId statementId = newStatementId(); - Statement st = - Statement.builder() - .sessionId(sessionId) - .applicationId(sessionModel.getApplicationId()) - .jobId(sessionModel.getJobId()) - .stateStore(stateStore) - .statementId(statementId) - .langType(LangType.SQL) - .query(request.getQuery()) - .queryId(statementId.getId()) - .build(); - st.open(); - return statementId; - } else { - String errMsg = - String.format( - "can't submit statement, session should not be in end state, " - + "current session state is: %s", - sessionModel.getSessionState().getSessionState()); - LOG.debug(errMsg); - throw new IllegalStateException(errMsg); - } - } - } - - @Override - public Optional get(StatementId stID) { - return StateStore.getStatement(stateStore) - .apply(stID.getId()) - .map( - model -> - Statement.builder() - .sessionId(sessionId) - .applicationId(model.getApplicationId()) - .jobId(model.getJobId()) - .statementId(model.getStatementId()) - .langType(model.getLangType()) - .query(model.getQuery()) - .queryId(model.getQueryId()) - .stateStore(stateStore) - .statementModel(model) - .build()); - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index 4d919d5e2e..ec9775e60a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -5,11 +5,6 @@ package org.opensearch.sql.spark.execution.session; -import java.util.Optional; -import org.opensearch.sql.spark.execution.statement.QueryRequest; -import org.opensearch.sql.spark.execution.statement.Statement; -import org.opensearch.sql.spark.execution.statement.StatementId; - /** Session define the statement execution context. Each session is binding to one Spark Job. */ public interface Session { /** open session. */ @@ -18,22 +13,6 @@ public interface Session { /** close session. */ void close(); - /** - * submit {@link QueryRequest}. - * - * @param request {@link QueryRequest} - * @return {@link StatementId} - */ - StatementId submit(QueryRequest request); - - /** - * get {@link Statement}. - * - * @param stID {@link StatementId} - * @return {@link Statement} - */ - Optional get(StatementId stID); - SessionModel getSessionModel(); SessionId getSessionId(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 217af80caf..3d0916bac8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -10,7 +10,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; /** * Singleton Class @@ -19,14 +19,14 @@ */ @RequiredArgsConstructor public class SessionManager { - private final StateStore stateStore; + private final SessionStateStore stateStore; private final EMRServerlessClient emrServerlessClient; public Session createSession(CreateSessionRequest request) { InteractiveSession session = InteractiveSession.builder() .sessionId(newSessionId()) - .stateStore(stateStore) + .sessionStateStore(stateStore) .serverlessClient(emrServerlessClient) .build(); session.open(request); @@ -34,12 +34,12 @@ public Session createSession(CreateSessionRequest request) { } public Optional getSession(SessionId sid) { - Optional model = StateStore.getSession(stateStore).apply(sid.getSessionId()); + Optional model = stateStore.get(sid); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() .sessionId(sid) - .stateStore(stateStore) + .sessionStateStore(stateStore) .serverlessClient(emrServerlessClient) .sessionModel(model.get()) .build(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java index 806cdb083e..656f0ec8ce 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java @@ -12,16 +12,16 @@ import lombok.Builder; import lombok.Data; import lombok.SneakyThrows; +import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.sql.spark.execution.statestore.StateModel; /** Session data in flint.ql.sessions index. */ @Data @Builder -public class SessionModel extends StateModel { +public class SessionModel implements ToXContentObject { public static final String VERSION = "version"; public static final String TYPE = "type"; public static final String SESSION_TYPE = "sessionType"; @@ -73,27 +73,6 @@ public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) { .sessionId(new SessionId(copy.sessionId.getSessionId())) .sessionState(copy.sessionState) .datasourceName(copy.datasourceName) - .applicationId(copy.getApplicationId()) - .jobId(copy.jobId) - .error(UNKNOWN) - .lastUpdateTime(copy.getLastUpdateTime()) - .seqNo(seqNo) - .primaryTerm(primaryTerm) - .build(); - } - - public static SessionModel copyWithState( - SessionModel copy, SessionState state, long seqNo, long primaryTerm) { - return builder() - .version(copy.version) - .sessionType(copy.sessionType) - .sessionId(new SessionId(copy.sessionId.getSessionId())) - .sessionState(state) - .datasourceName(copy.datasourceName) - .applicationId(copy.getApplicationId()) - .jobId(copy.jobId) - .error(UNKNOWN) - .lastUpdateTime(copy.getLastUpdateTime()) .seqNo(seqNo) .primaryTerm(primaryTerm) .build(); @@ -161,9 +140,4 @@ public static SessionModel initInteractiveSession( .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) .build(); } - - @Override - public String getId() { - return sessionId.getSessionId(); - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java index a4da957f12..509d5105e9 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -5,9 +5,7 @@ package org.opensearch.sql.spark.execution.session; -import com.google.common.collect.ImmutableList; import java.util.Arrays; -import java.util.List; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -19,8 +17,6 @@ public enum SessionState { DEAD("dead"), FAIL("fail"); - public static List END_STATE = ImmutableList.of(DEAD, FAIL); - private final String sessionState; SessionState(String sessionState) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java deleted file mode 100644 index 10061404ca..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java +++ /dev/null @@ -1,15 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statement; - -import lombok.Data; -import org.opensearch.sql.spark.rest.model.LangType; - -@Data -public class QueryRequest { - private final LangType langType; - private final String query; -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java deleted file mode 100644 index 8fcedb5fca..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statement; - -import static org.opensearch.sql.spark.execution.statement.StatementModel.submitStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.createStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; - -import lombok.Builder; -import lombok.Getter; -import lombok.Setter; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.index.engine.DocumentMissingException; -import org.opensearch.index.engine.VersionConflictEngineException; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.rest.model.LangType; - -/** Statement represent query to execute in session. One statement map to one session. */ -@Getter -@Builder -public class Statement { - private static final Logger LOG = LogManager.getLogger(); - - private final SessionId sessionId; - private final String applicationId; - private final String jobId; - private final StatementId statementId; - private final LangType langType; - private final String query; - private final String queryId; - private final StateStore stateStore; - - @Setter private StatementModel statementModel; - - /** Open a statement. */ - public void open() { - try { - statementModel = - submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId); - statementModel = createStatement(stateStore).apply(statementModel); - } catch (VersionConflictEngineException e) { - String errorMsg = "statement already exist. " + statementId; - LOG.error(errorMsg); - throw new IllegalStateException(errorMsg); - } - } - - /** Cancel a statement. */ - public void cancel() { - if (statementModel.getStatementState().equals(StatementState.RUNNING)) { - String errorMsg = - String.format("can't cancel statement in waiting state. statement: %s.", statementId); - LOG.error(errorMsg); - throw new IllegalStateException(errorMsg); - } - try { - this.statementModel = - updateStatementState(stateStore).apply(this.statementModel, StatementState.CANCELLED); - } catch (DocumentMissingException e) { - String errorMsg = - String.format("cancel statement failed. no statement found. statement: %s.", statementId); - LOG.error(errorMsg); - throw new IllegalStateException(errorMsg); - } catch (VersionConflictEngineException e) { - this.statementModel = - getStatement(stateStore).apply(statementModel.getId()).orElse(this.statementModel); - String errorMsg = - String.format( - "cancel statement failed. current statementState: %s " + "statement: %s.", - this.statementModel.getStatementState(), statementId); - LOG.error(errorMsg); - throw new IllegalStateException(errorMsg); - } - } - - public StatementState getStatementState() { - return statementModel.getStatementState(); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java deleted file mode 100644 index 4baff71493..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statement; - -import lombok.Data; -import org.apache.commons.lang3.RandomStringUtils; - -@Data -public class StatementId { - private final String id; - - public static StatementId newStatementId() { - return new StatementId(RandomStringUtils.random(10, true, true)); - } - - @Override - public String toString() { - return "statementId=" + id; - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java deleted file mode 100644 index c7f681c541..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statement; - -import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID; -import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID; -import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; - -import java.io.IOException; -import lombok.Builder; -import lombok.Data; -import lombok.SneakyThrows; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.core.xcontent.XContentParserUtils; -import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.statestore.StateModel; -import org.opensearch.sql.spark.rest.model.LangType; - -/** Statement data in flint.ql.sessions index. */ -@Data -@Builder -public class StatementModel extends StateModel { - public static final String VERSION = "version"; - public static final String TYPE = "type"; - public static final String STATEMENT_STATE = "state"; - public static final String STATEMENT_ID = "statementId"; - public static final String SESSION_ID = "sessionId"; - public static final String LANG = "lang"; - public static final String QUERY = "query"; - public static final String QUERY_ID = "queryId"; - public static final String SUBMIT_TIME = "submitTime"; - public static final String ERROR = "error"; - public static final String UNKNOWN = "unknown"; - public static final String STATEMENT_DOC_TYPE = "statement"; - - private final String version; - private final StatementState statementState; - private final StatementId statementId; - private final SessionId sessionId; - private final String applicationId; - private final String jobId; - private final LangType langType; - private final String query; - private final String queryId; - private final long submitTime; - private final String error; - - private final long seqNo; - private final long primaryTerm; - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder - .startObject() - .field(VERSION, version) - .field(TYPE, STATEMENT_DOC_TYPE) - .field(STATEMENT_STATE, statementState.getState()) - .field(STATEMENT_ID, statementId.getId()) - .field(SESSION_ID, sessionId.getSessionId()) - .field(APPLICATION_ID, applicationId) - .field(JOB_ID, jobId) - .field(LANG, langType.getText()) - .field(QUERY, query) - .field(QUERY_ID, queryId) - .field(SUBMIT_TIME, submitTime) - .field(ERROR, error) - .endObject(); - return builder; - } - - public static StatementModel copy(StatementModel copy, long seqNo, long primaryTerm) { - return builder() - .version("1.0") - .statementState(copy.statementState) - .statementId(copy.statementId) - .sessionId(copy.sessionId) - .applicationId(copy.applicationId) - .jobId(copy.jobId) - .langType(copy.langType) - .query(copy.query) - .queryId(copy.queryId) - .submitTime(copy.submitTime) - .error(copy.error) - .seqNo(seqNo) - .primaryTerm(primaryTerm) - .build(); - } - - public static StatementModel copyWithState( - StatementModel copy, StatementState state, long seqNo, long primaryTerm) { - return builder() - .version("1.0") - .statementState(state) - .statementId(copy.statementId) - .sessionId(copy.sessionId) - .applicationId(copy.applicationId) - .jobId(copy.jobId) - .langType(copy.langType) - .query(copy.query) - .queryId(copy.queryId) - .submitTime(copy.submitTime) - .error(copy.error) - .seqNo(seqNo) - .primaryTerm(primaryTerm) - .build(); - } - - @SneakyThrows - public static StatementModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { - StatementModel.StatementModelBuilder builder = StatementModel.builder(); - XContentParserUtils.ensureExpectedToken( - XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { - String fieldName = parser.currentName(); - parser.nextToken(); - switch (fieldName) { - case VERSION: - builder.version(parser.text()); - break; - case TYPE: - // do nothing - break; - case STATEMENT_STATE: - builder.statementState(StatementState.fromString(parser.text())); - break; - case STATEMENT_ID: - builder.statementId(new StatementId(parser.text())); - break; - case SESSION_ID: - builder.sessionId(new SessionId(parser.text())); - break; - case APPLICATION_ID: - builder.applicationId(parser.text()); - break; - case JOB_ID: - builder.jobId(parser.text()); - break; - case LANG: - builder.langType(LangType.fromString(parser.text())); - break; - case QUERY: - builder.query(parser.text()); - break; - case QUERY_ID: - builder.queryId(parser.text()); - break; - case SUBMIT_TIME: - builder.submitTime(parser.longValue()); - break; - case ERROR: - builder.error(parser.text()); - break; - } - } - builder.seqNo(seqNo); - builder.primaryTerm(primaryTerm); - return builder.build(); - } - - public static StatementModel submitStatement( - SessionId sid, - String applicationId, - String jobId, - StatementId statementId, - LangType langType, - String query, - String queryId) { - return builder() - .version("1.0") - .statementState(WAITING) - .statementId(statementId) - .sessionId(sid) - .applicationId(applicationId) - .jobId(jobId) - .langType(langType) - .query(query) - .queryId(queryId) - .submitTime(System.currentTimeMillis()) - .error(UNKNOWN) - .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO) - .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) - .build(); - } - - @Override - public String getId() { - return statementId.getId(); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java deleted file mode 100644 index 33f7f5e831..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statement; - -import java.util.Arrays; -import java.util.Map; -import java.util.stream.Collectors; -import lombok.Getter; - -/** {@link Statement} State. */ -@Getter -public enum StatementState { - WAITING("waiting"), - RUNNING("running"), - SUCCESS("success"), - FAILED("failed"), - CANCELLED("cancelled"); - - private final String state; - - StatementState(String state) { - this.state = state; - } - - private static Map STATES = - Arrays.stream(StatementState.values()) - .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); - - public static StatementState fromString(String key) { - if (STATES.containsKey(key)) { - return STATES.get(key); - } - throw new IllegalArgumentException("Invalid statement state: " + key); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java new file mode 100644 index 0000000000..6ddce55360 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionModel; + +@RequiredArgsConstructor +public class SessionStateStore { + private static final Logger LOG = LogManager.getLogger(); + + private final String indexName; + private final Client client; + + public SessionModel create(SessionModel session) { + try { + IndexRequest indexRequest = + new IndexRequest(indexName) + .id(session.getSessionId().getSessionId()) + .source(session.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .setIfSeqNo(session.getSeqNo()) + .setIfPrimaryTerm(session.getPrimaryTerm()) + .create(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", session.getSessionId()); + return SessionModel.of(session, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. id: %s, error: %s", + session.getSessionId(), + indexResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public Optional get(SessionId sid) { + try { + GetRequest getRequest = new GetRequest().index(indexName).id(sid.getSessionId()); + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + SessionModel.fromXContent( + parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java deleted file mode 100644 index b5bf31a6ba..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statestore; - -import org.opensearch.core.xcontent.ToXContentObject; -import org.opensearch.core.xcontent.XContentParser; - -public abstract class StateModel implements ToXContentObject { - - public abstract String getId(); - - public abstract long getSeqNo(); - - public abstract long getPrimaryTerm(); - - public interface CopyBuilder { - T of(T copy, long seqNo, long primaryTerm); - } - - public interface StateCopyBuilder { - T of(T copy, S state, long seqNo, long primaryTerm); - } - - public interface FromXContent { - T fromXContent(XContentParser parser, long seqNo, long primaryTerm); - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java deleted file mode 100644 index bd72b17353..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statestore; - -import java.io.IOException; -import java.util.Locale; -import java.util.Optional; -import java.util.function.BiFunction; -import java.util.function.Function; -import lombok.RequiredArgsConstructor; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.update.UpdateRequest; -import org.opensearch.action.update.UpdateResponse; -import org.opensearch.client.Client; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.execution.session.SessionModel; -import org.opensearch.sql.spark.execution.session.SessionState; -import org.opensearch.sql.spark.execution.statement.StatementModel; -import org.opensearch.sql.spark.execution.statement.StatementState; - -@RequiredArgsConstructor -public class StateStore { - private static final Logger LOG = LogManager.getLogger(); - - private final String indexName; - private final Client client; - - protected T create(T st, StateModel.CopyBuilder builder) { - try { - IndexRequest indexRequest = - new IndexRequest(indexName) - .id(st.getId()) - .source(st.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .setIfSeqNo(st.getSeqNo()) - .setIfPrimaryTerm(st.getPrimaryTerm()) - .create(true) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client.index(indexRequest).actionGet(); - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("Successfully created doc. id: {}", st.getId()); - return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed create doc. id: %s, error: %s", - st.getId(), - indexResponse.getResult().getLowercase())); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - protected Optional get(String sid, StateModel.FromXContent builder) { - try { - GetRequest getRequest = new GetRequest().index(indexName).id(sid); - GetResponse getResponse = client.get(getRequest).actionGet(); - if (getResponse.isExists()) { - XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - getResponse.getSourceAsString()); - parser.nextToken(); - return Optional.of( - builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); - } else { - return Optional.empty(); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - protected T updateState( - T st, S state, StateModel.StateCopyBuilder builder) { - try { - T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm()); - UpdateRequest updateRequest = - new UpdateRequest() - .index(indexName) - .id(model.getId()) - .setIfSeqNo(model.getSeqNo()) - .setIfPrimaryTerm(model.getPrimaryTerm()) - .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .fetchSource(true) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - UpdateResponse updateResponse = client.update(updateRequest).actionGet(); - if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { - LOG.debug("Successfully update doc. id: {}", st.getId()); - return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed update doc. id: %s, error: %s", - st.getId(), - updateResponse.getResult().getLowercase())); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - /** Helper Functions */ - public static Function createStatement(StateStore stateStore) { - return (st) -> stateStore.create(st, StatementModel::copy); - } - - public static Function> getStatement(StateStore stateStore) { - return (docId) -> stateStore.get(docId, StatementModel::fromXContent); - } - - public static BiFunction updateStatementState( - StateStore stateStore) { - return (old, state) -> stateStore.updateState(old, state, StatementModel::copyWithState); - } - - public static Function createSession(StateStore stateStore) { - return (session) -> stateStore.create(session, SessionModel::of); - } - - public static Function> getSession(StateStore stateStore) { - return (docId) -> stateStore.get(docId, SessionModel::fromXContent); - } - - public static BiFunction updateSessionState( - StateStore stateStore) { - return (old, state) -> stateStore.updateState(old, state, SessionModel::copyWithState); - } -} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 488252d05a..53dc211ded 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -7,7 +7,6 @@ import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -21,7 +20,7 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; import org.opensearch.test.OpenSearchSingleNodeTestCase; /** mock-maker-inline does not work with OpenSearchTestCase. */ @@ -31,13 +30,13 @@ public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; - private StateStore stateStore; + private SessionStateStore stateStore; @Before public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new StateStore(indexName, client()); + stateStore = new SessionStateStore(indexName, client()); createIndex(indexName); } @@ -51,7 +50,7 @@ public void openCloseSession() { InteractiveSession session = InteractiveSession.builder() .sessionId(SessionId.newSessionId()) - .stateStore(stateStore) + .sessionStateStore(stateStore) .serverlessClient(emrsClient) .build(); @@ -75,7 +74,7 @@ public void openSessionFailedConflict() { InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .stateStore(stateStore) + .sessionStateStore(stateStore) .serverlessClient(emrsClient) .build(); session.open(new CreateSessionRequest(startJobRequest, "datasource")); @@ -83,7 +82,7 @@ public void openSessionFailedConflict() { InteractiveSession duplicateSession = InteractiveSession.builder() .sessionId(sessionId) - .stateStore(stateStore) + .sessionStateStore(stateStore) .serverlessClient(emrsClient) .build(); IllegalStateException exception = @@ -99,15 +98,15 @@ public void closeNotExistSession() { InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .stateStore(stateStore) + .sessionStateStore(stateStore) .serverlessClient(emrsClient) .build(); session.open(new CreateSessionRequest(startJobRequest, "datasource")); - client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); + client().delete(new DeleteRequest(indexName, sessionId.getSessionId())); IllegalStateException exception = assertThrows(IllegalStateException.class, session::close); - assertEquals("session does not exist. " + sessionId, exception.getMessage()); + assertEquals("session not exist. " + sessionId, exception.getMessage()); emrsClient.cancelJobRunCalled(0); } @@ -143,9 +142,9 @@ public void sessionManagerGetSessionNotExist() { @RequiredArgsConstructor static class TestSession { private final Session session; - private final StateStore stateStore; + private final SessionStateStore stateStore; - public static TestSession testSession(Session session, StateStore stateStore) { + public static TestSession testSession(Session session, SessionStateStore stateStore) { return new TestSession(session, stateStore); } @@ -153,7 +152,7 @@ public TestSession assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional sessionStoreState = - getSession(stateStore).apply(session.getSessionModel().getId()); + stateStore.get(session.getSessionModel().getSessionId()); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); @@ -181,7 +180,7 @@ public TestSession close() { } } - public static class TestEMRServerlessClient implements EMRServerlessClient { + static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; private int cancelJobRunCalled = 0; diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index 95b85613be..d35105f787 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -5,20 +5,29 @@ package org.opensearch.sql.spark.execution.session; +import static org.junit.jupiter.api.Assertions.*; + import org.junit.After; import org.junit.Before; +import org.mockito.MockMakers; +import org.mockito.MockSettings; +import org.mockito.Mockito; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; -import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.execution.statestore.SessionStateStore; import org.opensearch.test.OpenSearchSingleNodeTestCase; class SessionManagerTest extends OpenSearchSingleNodeTestCase { private static final String indexName = "mockindex"; - private StateStore stateStore; + // mock-maker-inline does not work with OpenSearchTestCase. make sure use mockSettings when mock. + private static final MockSettings mockSettings = + Mockito.withSettings().mockMaker(MockMakers.SUBCLASS); + + private SessionStateStore stateStore; @Before public void setup() { - stateStore = new StateStore(indexName, client()); + stateStore = new SessionStateStore(indexName, client()); createIndex(indexName); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java deleted file mode 100644 index b7af1123ba..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statement; - -import static org.junit.Assert.assertThrows; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -class StatementStateTest { - @Test - public void invalidStatementState() { - IllegalArgumentException exception = - assertThrows(IllegalArgumentException.class, () -> StatementState.fromString("invalid")); - Assertions.assertEquals("Invalid statement state: invalid", exception.getMessage()); - } -} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java deleted file mode 100644 index 331955e14e..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java +++ /dev/null @@ -1,356 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statement; - -import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; -import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; -import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; -import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; - -import java.util.HashMap; -import java.util.Optional; -import lombok.RequiredArgsConstructor; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; -import org.opensearch.action.delete.DeleteRequest; -import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.execution.session.CreateSessionRequest; -import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; -import org.opensearch.sql.spark.execution.session.Session; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.session.SessionState; -import org.opensearch.sql.spark.execution.statestore.StateStore; -import org.opensearch.sql.spark.rest.model.LangType; -import org.opensearch.test.OpenSearchSingleNodeTestCase; - -public class StatementTest extends OpenSearchSingleNodeTestCase { - - private static final String indexName = "mockindex"; - - private StartJobRequest startJobRequest; - private StateStore stateStore; - private InteractiveSessionTest.TestEMRServerlessClient emrsClient = - new InteractiveSessionTest.TestEMRServerlessClient(); - - @Before - public void setup() { - startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new StateStore(indexName, client()); - createIndex(indexName); - } - - @After - public void clean() { - client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); - } - - @Test - public void openThenCancelStatement() { - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(new StatementId("statementId")) - .langType(LangType.SQL) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); - - // submit statement - TestStatement testStatement = testStatement(st, stateStore); - testStatement - .open() - .assertSessionState(WAITING) - .assertStatementId(new StatementId("statementId")); - - // close statement - testStatement.cancel().assertSessionState(CANCELLED); - } - - @Test - public void openFailedBecauseConflict() { - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(new StatementId("statementId")) - .langType(LangType.SQL) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); - st.open(); - - // open statement with same statement id - Statement dupSt = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(new StatementId("statementId")) - .langType(LangType.SQL) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); - IllegalStateException exception = assertThrows(IllegalStateException.class, dupSt::open); - assertEquals("statement already exist. statementId=statementId", exception.getMessage()); - } - - @Test - public void cancelNotExistStatement() { - StatementId stId = new StatementId("statementId"); - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(stId) - .langType(LangType.SQL) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); - st.open(); - - client().delete(new DeleteRequest(indexName, stId.getId())); - - IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); - assertEquals( - String.format("cancel statement failed. no statement found. statement: %s.", stId), - exception.getMessage()); - } - - @Test - public void cancelFailedBecauseOfConflict() { - StatementId stId = new StatementId("statementId"); - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(stId) - .langType(LangType.SQL) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); - st.open(); - - StatementModel running = - updateStatementState(stateStore).apply(st.getStatementModel(), CANCELLED); - - assertEquals(StatementState.CANCELLED, running.getStatementState()); - - // cancel conflict - IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); - assertEquals( - String.format( - "cancel statement failed. current statementState: CANCELLED " + "statement: %s.", stId), - exception.getMessage()); - } - - @Test - public void cancelRunningStatementFailed() { - StatementId stId = new StatementId("statementId"); - Statement st = - Statement.builder() - .sessionId(new SessionId("sessionId")) - .applicationId("appId") - .jobId("jobId") - .statementId(stId) - .langType(LangType.SQL) - .query("query") - .queryId("statementId") - .stateStore(stateStore) - .build(); - st.open(); - - // update to running state - StatementModel model = st.getStatementModel(); - st.setStatementModel( - StatementModel.copyWithState( - st.getStatementModel(), - StatementState.RUNNING, - model.getSeqNo(), - model.getPrimaryTerm())); - - // cancel conflict - IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); - assertEquals( - String.format("can't cancel statement in waiting state. statement: %s.", stId), - exception.getMessage()); - } - - @Test - public void submitStatementInRunningSession() { - Session session = - new SessionManager(stateStore, emrsClient) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - - // App change state to running - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); - - StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); - assertFalse(statementId.getId().isEmpty()); - } - - @Test - public void submitStatementInNotStartedState() { - Session session = - new SessionManager(stateStore, emrsClient) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - - StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); - assertFalse(statementId.getId().isEmpty()); - } - - @Test - public void failToSubmitStatementInDeadState() { - Session session = - new SessionManager(stateStore, emrsClient) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); - - IllegalStateException exception = - assertThrows( - IllegalStateException.class, - () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); - assertEquals( - "can't submit statement, session should not be in end state, current session state is:" - + " dead", - exception.getMessage()); - } - - @Test - public void failToSubmitStatementInFailState() { - Session session = - new SessionManager(stateStore, emrsClient) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); - - IllegalStateException exception = - assertThrows( - IllegalStateException.class, - () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); - assertEquals( - "can't submit statement, session should not be in end state, current session state is:" - + " fail", - exception.getMessage()); - } - - @Test - public void newStatementFieldAssert() { - Session session = - new SessionManager(stateStore, emrsClient) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); - Optional statement = session.get(statementId); - - assertTrue(statement.isPresent()); - assertEquals(session.getSessionId(), statement.get().getSessionId()); - assertEquals("appId", statement.get().getApplicationId()); - assertEquals("jobId", statement.get().getJobId()); - assertEquals(statementId, statement.get().getStatementId()); - assertEquals(WAITING, statement.get().getStatementState()); - assertEquals(LangType.SQL, statement.get().getLangType()); - assertEquals("select 1", statement.get().getQuery()); - } - - @Test - public void failToSubmitStatementInDeletedSession() { - Session session = - new SessionManager(stateStore, emrsClient) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - - // other's delete session - client() - .delete(new DeleteRequest(indexName, session.getSessionId().getSessionId())) - .actionGet(); - - IllegalStateException exception = - assertThrows( - IllegalStateException.class, - () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); - assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage()); - } - - @Test - public void getStatementSuccess() { - Session session = - new SessionManager(stateStore, emrsClient) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - // App change state to running - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); - StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); - - Optional statement = session.get(statementId); - assertTrue(statement.isPresent()); - assertEquals(WAITING, statement.get().getStatementState()); - assertEquals(statementId, statement.get().getStatementId()); - } - - @Test - public void getStatementNotExist() { - Session session = - new SessionManager(stateStore, emrsClient) - .createSession(new CreateSessionRequest(startJobRequest, "datasource")); - // App change state to running - updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); - - Optional statement = session.get(StatementId.newStatementId()); - assertFalse(statement.isPresent()); - } - - @RequiredArgsConstructor - static class TestStatement { - private final Statement st; - private final StateStore stateStore; - - public static TestStatement testStatement(Statement st, StateStore stateStore) { - return new TestStatement(st, stateStore); - } - - public TestStatement assertSessionState(StatementState expected) { - assertEquals(expected, st.getStatementModel().getStatementState()); - - Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); - assertTrue(model.isPresent()); - assertEquals(expected, model.get().getStatementState()); - - return this; - } - - public TestStatement assertStatementId(StatementId expected) { - assertEquals(expected, st.getStatementModel().getStatementId()); - - Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); - assertTrue(model.isPresent()); - assertEquals(expected, model.get().getStatementId()); - return this; - } - - public TestStatement open() { - st.open(); - return this; - } - - public TestStatement cancel() { - st.cancel(); - return this; - } - } -} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java new file mode 100644 index 0000000000..9c779555d7 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.client.Client; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionModel; + +@ExtendWith(MockitoExtension.class) +class SessionStateStoreTest { + @Mock(answer = RETURNS_DEEP_STUBS) + private Client client; + + @Mock private IndexResponse indexResponse; + + @Test + public void createWithException() { + when(client.index(any()).actionGet()).thenReturn(indexResponse); + doReturn(DocWriteResponse.Result.NOT_FOUND).when(indexResponse).getResult(); + SessionModel sessionModel = + SessionModel.initInteractiveSession( + "appId", "jobId", SessionId.newSessionId(), "datasource"); + SessionStateStore sessionStateStore = new SessionStateStore("indexName", client); + + assertThrows(RuntimeException.class, () -> sessionStateStore.create(sessionModel)); + } +}