Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,25 @@

package org.opensearch.sql.api.parser;

import static org.opensearch.sql.ast.dsl.AstDSL.existsSubquery;
import static org.opensearch.sql.ast.dsl.AstDSL.inSubquery;
import static org.opensearch.sql.ast.dsl.AstDSL.join;

import java.util.Optional;
import org.antlr.v4.runtime.tree.ParseTree;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.statement.Query;
import org.opensearch.sql.ast.statement.Statement;
import org.opensearch.sql.ast.tree.Join.JoinType;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.sql.antlr.SQLSyntaxParser;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ExistsSubqueryExpressionAtomContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InSubqueryPredicateContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.JoinClauseContext;
import org.opensearch.sql.sql.parser.AstBuilder;
import org.opensearch.sql.sql.parser.AstExpressionBuilder;
import org.opensearch.sql.sql.parser.AstStatementBuilder;

/** SQL query parser that produces {@link UnresolvedPlan} using the V2 ANTLR grammar. */
Expand Down Expand Up @@ -52,6 +58,11 @@ private static class ExtendedAstBuilder extends AstBuilder {
super(query);
}

@Override
protected AstExpressionBuilder createExpressionBuilder() {
return new ExtendedAstExpressionBuilder();
}

@Override
public UnresolvedPlan visitJoinClause(JoinClauseContext ctx) {
JoinType joinType = toJoinType(ctx);
Expand All @@ -69,5 +80,27 @@ private JoinType toJoinType(JoinClauseContext ctx) {
default -> JoinType.INNER;
};
}

/**
* Expression builder with IN/EXISTS subquery support. Accesses the enclosing AstBuilder to
* visit subquery plan nodes. Must be created via {@link #createExpressionBuilder()} because the
* enclosing {@code this} reference is not available during {@code super()} construction.
*/
private class ExtendedAstExpressionBuilder extends AstExpressionBuilder {

@Override
public UnresolvedExpression visitInSubqueryPredicate(InSubqueryPredicateContext ctx) {
UnresolvedPlan subquery = ExtendedAstBuilder.this.visit(ctx.querySpecification());
UnresolvedExpression inExpr = inSubquery(subquery, visit(ctx.predicate()));
return (ctx.NOT() != null) ? new Not(inExpr) : inExpr;
}

@Override
public UnresolvedExpression visitExistsSubqueryExpressionAtom(
ExistsSubqueryExpressionAtomContext ctx) {
UnresolvedPlan subquery = ExtendedAstBuilder.this.visit(ctx.querySpecification());
return existsSubquery(subquery);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,80 @@ public void testJoinWithFilterAndOrderBy() {
LogicalTableScan(table=[[catalog, departments]])
""");
}

@Test
public void testInSubquery() {
givenQuery(
"""
SELECT name FROM catalog.employees
WHERE age IN (SELECT age FROM catalog.departments WHERE dept_name = 'Engineering')
""")
.assertPlan(
"""
LogicalProject(name=[$1])
LogicalFilter(condition=[IN($2, {
LogicalProject(age=[$cor0.age])
LogicalFilter(condition=[=($1, 'Engineering')])
LogicalTableScan(table=[[catalog, departments]])
})], variablesSet=[[$cor0]])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testExistsSubquery() {
givenQuery(
"""
SELECT name FROM catalog.employees
WHERE EXISTS (SELECT 1 FROM catalog.departments WHERE dept_id = age)
""")
.assertPlan(
"""
LogicalProject(name=[$1])
LogicalFilter(condition=[EXISTS({
LogicalProject(1=[1])
LogicalFilter(condition=[=($0, $cor0.age)])
LogicalTableScan(table=[[catalog, departments]])
})], variablesSet=[[$cor0]])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testNotInSubquery() {
givenQuery(
"""
SELECT name FROM catalog.employees
WHERE age NOT IN (SELECT age FROM catalog.departments WHERE dept_name = 'Engineering')
""")
.assertPlan(
"""
LogicalProject(name=[$1])
LogicalFilter(condition=[NOT(IN($2, {
LogicalProject(age=[$cor0.age])
LogicalFilter(condition=[=($1, 'Engineering')])
LogicalTableScan(table=[[catalog, departments]])
}))], variablesSet=[[$cor0]])
LogicalTableScan(table=[[catalog, employees]])
""");
}

@Test
public void testNotExistsSubquery() {
givenQuery(
"""
SELECT name FROM catalog.employees
WHERE NOT EXISTS (SELECT 1 FROM catalog.departments WHERE dept_id = age)
""")
.assertPlan(
"""
LogicalProject(name=[$1])
LogicalFilter(condition=[NOT(EXISTS({
LogicalProject(1=[1])
LogicalFilter(condition=[=($0, $cor0.age)])
LogicalTableScan(table=[[catalog, departments]])
}))], variablesSet=[[$cor0]])
LogicalTableScan(table=[[catalog, employees]])
""");
}
}
10 changes: 10 additions & 0 deletions core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import org.opensearch.sql.ast.expression.When;
import org.opensearch.sql.ast.expression.WindowFunction;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.AppendPipe;
import org.opensearch.sql.ast.tree.Bin;
Expand Down Expand Up @@ -771,4 +773,12 @@ public static UnresolvedPlan join(
Optional.empty(),
Argument.ArgumentMap.empty());
}

public static InSubquery inSubquery(UnresolvedPlan query, UnresolvedExpression... values) {
return new InSubquery(List.of(values), query);
}

public static ExistsSubquery existsSubquery(UnresolvedPlan query) {
return new ExistsSubquery(query);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,9 @@ private List<RexNode> expandProjectFields(
.filter(addedFields::add)
.forEach(field -> expandedFields.add(context.relBuilder.field(field)));
}
case Alias alias -> {
expandedFields.add(rexVisitor.analyze(alias, context));
}
default ->
throw new IllegalStateException(
"Unexpected expression type in project list: " + expr.getClass().getSimpleName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,13 @@ public void testLeftJoinFallback() throws IOException {
.formatted(TEST_INDEX_PEOPLE, TEST_INDEX_DOG));
verifyDataRows(result, rows("Daenerys", "rex"));
}

@Test
public void testInSubqueryFallback() throws IOException {
JSONObject result =
executeQuery(
"SELECT a.firstname FROM %s a WHERE a.firstname IN (SELECT holdersName FROM %s)"
.formatted(TEST_INDEX_PEOPLE, TEST_INDEX_DOG));
verifyDataRows(result, rows("Daenerys"), rows("Hattie"));
}
}
2 changes: 2 additions & 0 deletions sql/src/main/antlr/OpenSearchSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ predicate
| left = predicate NOT? LIKE right = predicate # likePredicate
| left = predicate REGEXP right = predicate # regexpPredicate
| predicate NOT? IN '(' expressions ')' # inPredicate
| predicate NOT? IN '(' querySpecification ')' # inSubqueryPredicate
;

expressions
Expand All @@ -333,6 +334,7 @@ expressionAtom
| columnName # fullColumnNameExpressionAtom
| functionCall # functionCallExpressionAtom
| LR_BRACKET expression RR_BRACKET # nestedExpressionAtom
| EXISTS LR_BRACKET querySpecification RR_BRACKET # existsSubqueryExpressionAtom
| left = expressionAtom mathOperator = (STAR | SLASH | MODULE) right = expressionAtom # mathExpressionAtom
| left = expressionAtom mathOperator = (PLUS | MINUS) right = expressionAtom # mathExpressionAtom
;
Expand Down
14 changes: 11 additions & 3 deletions sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.Collections;
import java.util.Locale;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.antlr.v4.runtime.tree.ParseTree;
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.AllFields;
Expand All @@ -50,10 +49,9 @@
import org.opensearch.sql.sql.parser.context.ParsingContext;

/** Abstract syntax tree (AST) builder. */
@RequiredArgsConstructor
public class AstBuilder extends OpenSearchSQLParserBaseVisitor<UnresolvedPlan> {

private final AstExpressionBuilder expressionBuilder = new AstExpressionBuilder();
private final AstExpressionBuilder expressionBuilder;

/** Parsing context stack that contains context for current query parsing. */
private final ParsingContext context = new ParsingContext();
Expand All @@ -64,6 +62,11 @@ public class AstBuilder extends OpenSearchSQLParserBaseVisitor<UnresolvedPlan> {
*/
private final String query;

public AstBuilder(String query) {
this.query = query;
this.expressionBuilder = createExpressionBuilder();
}

@Override
public UnresolvedPlan visitShowStatement(OpenSearchSQLParser.ShowStatementContext ctx) {
final UnresolvedExpression tableFilter = visitAstExpression(ctx.tableFilter());
Expand Down Expand Up @@ -279,6 +282,11 @@ protected UnresolvedExpression visitAstExpression(ParseTree tree) {
return expressionBuilder.visit(tree);
}

/** Override to provide a custom expression builder (e.g., with subquery support). */
protected AstExpressionBuilder createExpressionBuilder() {
return new AstExpressionBuilder();
}

private UnresolvedExpression visitSelectItem(SelectElementContext ctx) {
String name = StringUtils.unquoteIdentifier(getTextInQuery(ctx.expression(), query));
UnresolvedExpression expr = visitAstExpression(ctx.expression());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DataTypeFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DateLiteralContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DistinctCountFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ExistsSubqueryExpressionAtomContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ExtractFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FilterClauseContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FilteredAggregationFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FunctionArgContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.GetFormatFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.HighlightFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InPredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InSubqueryPredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IsNullPredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.LikePredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext;
Expand Down Expand Up @@ -82,6 +84,7 @@
import org.opensearch.sql.ast.dsl.AstDSL;
import org.opensearch.sql.ast.expression.*;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser;
Expand Down Expand Up @@ -668,4 +671,17 @@ private List<UnresolvedExpression> getExtractFunctionArguments(ExtractFunctionCa
visitFunctionArg(ctx.extractFunction().functionArg()));
return args;
}

@Override
public UnresolvedExpression visitInSubqueryPredicate(InSubqueryPredicateContext ctx) {
throw new SyntaxCheckException(
"IN subquery is not supported in the V2 SQL engine. Falling back to legacy engine.");
}

@Override
public UnresolvedExpression visitExistsSubqueryExpressionAtom(
ExistsSubqueryExpressionAtomContext ctx) {
throw new SyntaxCheckException(
"EXISTS subquery is not supported in the V2 SQL engine. Falling back to legacy engine.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -756,4 +756,18 @@ public UnresolvedPlan visitJoinClause(OpenSearchSQLParser.JoinClauseContext ctx)
};
assertNotNull(new SQLSyntaxParser().parse(query).accept(builder));
}

@Test
public void in_subquery_throws_syntax_check_exception() {
assertThrows(
SyntaxCheckException.class,
() -> buildAST("SELECT * FROM t WHERE age IN (SELECT age FROM t2)"));
}

@Test
public void exists_subquery_throws_syntax_check_exception() {
assertThrows(
SyntaxCheckException.class,
() -> buildAST("SELECT * FROM t WHERE EXISTS (SELECT 1 FROM t2)"));
}
}
Loading