Skip to content

Commit

Permalink
Fix SQLQueryUtils to extract multiple tables (#2784)
Browse files Browse the repository at this point in the history
* Fix SQLQueryUtils to extract multiple tables

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

* Improve test coverage

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>

---------

Signed-off-by: Tomoyuki Morita <moritato@amazon.com>
  • Loading branch information
ykmr1224 committed Jun 28, 2024
1 parent 49e2e0e commit 883cc7e
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 152 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.sql.spark.utils;

import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import lombok.Getter;
import lombok.experimental.UtilityClass;
Expand All @@ -18,6 +20,7 @@
import org.opensearch.sql.spark.antlr.parser.FlintSparkSqlExtensionsParser;
import org.opensearch.sql.spark.antlr.parser.SqlBaseLexer;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParser.IdentifierReferenceContext;
import org.opensearch.sql.spark.antlr.parser.SqlBaseParserBaseVisitor;
import org.opensearch.sql.spark.dispatcher.model.FlintIndexOptions;
import org.opensearch.sql.spark.dispatcher.model.FullyQualifiedTableName;
Expand All @@ -32,16 +35,15 @@
@UtilityClass
public class SQLQueryUtils {

// TODO Handle cases where the query has multiple table Names.
public static FullyQualifiedTableName extractFullyQualifiedTableName(String sqlQuery) {
public static List<FullyQualifiedTableName> extractFullyQualifiedTableNames(String sqlQuery) {
SqlBaseParser sqlBaseParser =
new SqlBaseParser(
new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery))));
sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener());
SqlBaseParser.StatementContext statement = sqlBaseParser.statement();
SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor();
statement.accept(sparkSqlTableNameVisitor);
return sparkSqlTableNameVisitor.getFullyQualifiedTableName();
return sparkSqlTableNameVisitor.getFullyQualifiedTableNames();
}

public static IndexQueryDetails extractIndexDetails(String sqlQuery) {
Expand Down Expand Up @@ -73,23 +75,21 @@ public static boolean isFlintExtensionQuery(String sqlQuery) {

public static class SparkSqlTableNameVisitor extends SqlBaseParserBaseVisitor<Void> {

@Getter private FullyQualifiedTableName fullyQualifiedTableName;
@Getter private List<FullyQualifiedTableName> fullyQualifiedTableNames = new LinkedList<>();

public SparkSqlTableNameVisitor() {
this.fullyQualifiedTableName = new FullyQualifiedTableName();
}
public SparkSqlTableNameVisitor() {}

@Override
public Void visitTableName(SqlBaseParser.TableNameContext ctx) {
fullyQualifiedTableName = new FullyQualifiedTableName(ctx.getText());
return super.visitTableName(ctx);
public Void visitIdentifierReference(IdentifierReferenceContext ctx) {
fullyQualifiedTableNames.add(new FullyQualifiedTableName(ctx.getText()));
return super.visitIdentifierReference(ctx);
}

@Override
public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitDropTable(ctx);
Expand All @@ -99,7 +99,7 @@ public Void visitDropTable(SqlBaseParser.DropTableContext ctx) {
public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitDescribeRelation(ctx);
Expand All @@ -110,7 +110,7 @@ public Void visitDescribeRelation(SqlBaseParser.DescribeRelationContext ctx) {
public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) {
for (ParseTree parseTree : ctx.children) {
if (parseTree instanceof SqlBaseParser.IdentifierReferenceContext) {
fullyQualifiedTableName = new FullyQualifiedTableName(parseTree.getText());
fullyQualifiedTableNames.add(new FullyQualifiedTableName(parseTree.getText()));
}
}
return super.visitCreateTableHeader(ctx);
Expand Down
Loading

0 comments on commit 883cc7e

Please sign in to comment.