Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Support ordinal aliases in GROUP and ORDER BY clauses #248

Merged
merged 15 commits into from
Oct 24, 2019
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import com.amazon.opendistroforelasticsearch.sql.rewriter.matchtoterm.TermFieldRewriter;
import com.amazon.opendistroforelasticsearch.sql.rewriter.matchtoterm.TermFieldRewriter.TermRewriterFilter;
import com.amazon.opendistroforelasticsearch.sql.rewriter.nestedfield.NestedFieldRewriter;
import com.amazon.opendistroforelasticsearch.sql.rewriter.ordinal.OrdinalRewriterRule;
import com.amazon.opendistroforelasticsearch.sql.rewriter.parent.SQLExprParentSetterRule;
import com.amazon.opendistroforelasticsearch.sql.rewriter.subquery.SubQueryRewriteRule;
import org.elasticsearch.client.Client;
Expand Down Expand Up @@ -79,6 +80,7 @@ public static QueryAction create(Client client, String sql) throws SqlParseExcep

RewriteRuleExecutor<SQLQueryExpr> ruleExecutor = RewriteRuleExecutor.builder()
.withRule(new SQLExprParentSetterRule())
.withRule(new OrdinalRewriterRule(sql))
.withRule(new UnquoteIdentifierRule())
.withRule(new TableAliasPrefixRemoveRule())
.withRule(new SubQueryRewriteRule())
Expand Down Expand Up @@ -175,9 +177,8 @@ private static SQLExpr toSqlExpr(String sql) {
SQLExpr expr = parser.expr();

if (parser.getLexer().token() != Token.EOF) {
throw new ParserException("illegal sql expr : " + sql);
throw new ParserException("Illegal SQL expression : " + sql);
}

return expr;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.sql.rewriter.ordinal;

import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.alibaba.druid.sql.ast.expr.SQLIntegerExpr;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectOrderByItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlSelectGroupByExpr;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;
import com.alibaba.druid.sql.parser.SQLExprParser;
import com.amazon.opendistroforelasticsearch.sql.parser.ElasticSqlExprParser;
import com.amazon.opendistroforelasticsearch.sql.rewriter.RewriteRule;
import com.amazon.opendistroforelasticsearch.sql.rewriter.matchtoterm.VerificationException;

import java.util.List;

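/**
* Rewrite rule that replaces ordinal aliases in GROUP BY and ORDER BY clauses with the select item
* they refer to, e.g. "SELECT lastname FROM bank GROUP BY 1" becomes
* "SELECT lastname FROM bank GROUP BY lastname".
*/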
public class OrdinalRewriterRule implements RewriteRule<SQLQueryExpr> {

/** Raw SQL text of the query; toSqlExpr() parses it again to obtain fresh copies of the AST. */
private final String sql;

/**
 * Creates the rule for one query.
 *
 * @param sql raw SQL text; toSqlExpr() re-parses it and casts the result to SQLQueryExpr
 */
public OrdinalRewriterRule(String sql) {
this.sql = sql;
}

@Override
public boolean match(SQLQueryExpr root) {
SQLSelectQuery sqlSelectQuery = root.getSubQuery().getQuery();
if (!(sqlSelectQuery instanceof MySqlSelectQueryBlock)) {
// it could be SQLUnionQuery
return false;
}

MySqlSelectQueryBlock query = (MySqlSelectQueryBlock) sqlSelectQuery;

if (!hasGroupByWithOrdinals(query) && !hasOrderByWithOrdinals(query)) {
Copy link
Contributor

@galkk galkk Oct 23, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we seemingly have the same check twice. Could you add more comments explaining what's going on? From the code it looks like this checks whether we have ordinals, and the code below does the same thing in a different way, on the parsed raw SQL query. Could a single check work? Why do we need to parse the SQL again?

In general, could you add a better description of the algorithm to the PR or to the Javadoc, with examples? I find the intention of the code a bit hard to follow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comments to help understand the code better. Removed the redundant check.

return false;
}
return true;
}

/**
 * Replaces every ordinal found in the GROUP BY and ORDER BY clauses of {@code root}
 * with the expression of the select item at that position.
 *
 * @param root query expression to rewrite in place
 */
@Override
public void rewrite(SQLQueryExpr root) {
    // SQLSelectItem cannot be cloned, so the SQL text is parsed twice: one independent copy
    // supplies the replacement expressions for GROUP BY, the other those for ORDER BY.
    SQLQueryExpr groupBySource = toSqlExpr();
    SQLQueryExpr orderBySource = toSqlExpr();

    changeOrdinalAliasInGroupAndOrderBy(root, groupBySource, orderBySource);
}

/**
 * Visits {@code root} and substitutes integer literals in GROUP BY / ORDER BY items with the
 * select-item expression at that ordinal position.
 *
 * @param root      query expression that is modified in place
 * @param exprGroup separately parsed copy of the query whose select list feeds GROUP BY replacements
 * @param exprOrder separately parsed copy of the query whose select list feeds ORDER BY replacements
 */
private void changeOrdinalAliasInGroupAndOrderBy(SQLQueryExpr root,
                                                 SQLQueryExpr exprGroup,
                                                 SQLQueryExpr exprOrder) {
    final String groupByError = "Invalid ordinal [%s] specified in [GROUP BY %s]";
    final String orderByError = "Invalid ordinal [%s] specified in [ORDER BY %s]";

    final List<SQLSelectItem> groupSelectItems =
        ((MySqlSelectQueryBlock) exprGroup.getSubQuery().getQuery()).getSelectList();
    final List<SQLSelectItem> orderSelectItems =
        ((MySqlSelectQueryBlock) exprOrder.getSubQuery().getQuery()).getSelectList();

    MySqlASTVisitorAdapter ordinalReplacer = new MySqlASTVisitorAdapter() {

        @Override
        public boolean visit(MySqlSelectGroupByExpr groupByExpr) {
            SQLExpr current = groupByExpr.getExpr();
            if (!(current instanceof SQLIntegerExpr)) {
                return false;
            }

            int position = ((SQLIntegerExpr) current).getNumber().intValue();
            SQLExpr replacement = checkAndGet(groupSelectItems, position, groupByError);
            groupByExpr.setExpr(replacement);
            replacement.setParent(groupByExpr);
            return false;
        }

        @Override
        public boolean visit(SQLSelectOrderByItem orderByItem) {
            SQLExpr current = orderByItem.getExpr();

            if (current instanceof SQLIntegerExpr) {
                // plain ordinal: ORDER BY 1
                int position = ((SQLIntegerExpr) current).getNumber().intValue();
                SQLExpr replacement = checkAndGet(orderSelectItems, position, orderByError);
                orderByItem.setExpr(replacement);
                replacement.setParent(orderByItem);
                return false;
            }

            if (current instanceof SQLBinaryOpExpr) {
                SQLBinaryOpExpr binary = (SQLBinaryOpExpr) current;
                if (binary.getLeft() instanceof SQLIntegerExpr) {
                    // support ORDER BY IS NULL/NOT NULL: only the left operand is replaced
                    int position = ((SQLIntegerExpr) binary.getLeft()).getNumber().intValue();
                    SQLExpr replacement = checkAndGet(orderSelectItems, position, orderByError);
                    binary.setLeft(replacement);
                    replacement.setParent(binary);
                }
            }

            return false;
        }
    };

    root.accept(ordinalReplacer);
}

/**
 * Resolves a 1-based ordinal to the expression of the select item at that position.
 *
 * @param selectList select items of the query, in declaration order
 * @param ordinal    1-based position referenced in GROUP BY / ORDER BY
 * @param exception  format string for the error message; receives the ordinal twice
 * @return expression of the select item at position {@code ordinal}
 * @throws VerificationException if the ordinal is smaller than 1 or larger than the select list
 */
private SQLExpr checkAndGet(List<SQLSelectItem> selectList, Integer ordinal, String exception) {
    // Reject non-positive ordinals as well: without the lower bound, get(ordinal - 1) would fail
    // with an IndexOutOfBoundsException instead of the rule's own verification error.
    if (ordinal < 1 || ordinal > selectList.size()) {
        throw new VerificationException(String.format(exception, ordinal, ordinal));
    }

    return selectList.get(ordinal - 1).getExpr();
}

/**
 * Tells whether the GROUP BY clause contains at least one integer literal item.
 *
 * @param query query block to inspect
 * @return {@code true} if any GROUP BY item is a MySqlSelectGroupByExpr wrapping an SQLIntegerExpr
 */
private boolean hasGroupByWithOrdinals(MySqlSelectQueryBlock query) {
    if (query.getGroupBy() == null || query.getGroupBy().getItems().isEmpty()) {
        return false;
    }

    return query.getGroupBy().getItems().stream()
        .filter(item -> item instanceof MySqlSelectGroupByExpr)
        .map(item -> ((MySqlSelectGroupByExpr) item).getExpr())
        .anyMatch(inner -> inner instanceof SQLIntegerExpr);
}

/**
 * Tells whether the ORDER BY clause contains at least one ordinal reference.
 *
 * @param query query block to inspect
 * @return {@code true} if any ORDER BY item is an integer literal, or a binary expression
 *         whose left operand is an integer literal
 */
private boolean hasOrderByWithOrdinals(MySqlSelectQueryBlock query) {
    if (query.getOrderBy() == null || query.getOrderBy().getItems().isEmpty()) {
        return false;
    }

    for (SQLSelectOrderByItem item : query.getOrderBy().getItems()) {
        SQLExpr expr = item.getExpr();

        // Plain ordinal: ORDER BY 1
        if (expr instanceof SQLIntegerExpr) {
            return true;
        }

        // Ordinal used with IS NULL / IS NOT NULL, i.e. an AST of the shape
        //
        //        SQLSelectOrderByItem
        //                 |
        //          SQLBinaryOpExpr
        //            /         \
        //   SQLIntegerExpr   (right operand)
        //
        // NOTE(review): the operator itself is not inspected here, so any binary expression
        // with an integer literal on the left matches - confirm that this is intended.
        if (expr instanceof SQLBinaryOpExpr
                && ((SQLBinaryOpExpr) expr).getLeft() instanceof SQLIntegerExpr) {
            return true;
        }
    }

    return false;
}

/**
 * Parses the stored SQL text into a new AST on every call.
 *
 * @return freshly parsed expression, cast to SQLQueryExpr
 */
private SQLQueryExpr toSqlExpr() {
    SQLExprParser freshParser = new ElasticSqlExprParser(sql);
    return (SQLQueryExpr) freshParser.expr();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amazon.opendistroforelasticsearch.sql.unittest.rewriter.ordinal;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.expr.SQLQueryExpr;

import com.amazon.opendistroforelasticsearch.sql.rewriter.matchtoterm.VerificationException;
import com.amazon.opendistroforelasticsearch.sql.rewriter.ordinal.OrdinalRewriterRule;
import com.amazon.opendistroforelasticsearch.sql.util.SqlParserUtils;

import org.junit.Assert;
import org.junit.Test;

import static org.hamcrest.Matchers.containsString;

/**
* Test cases for ordinal aliases in GROUP BY and ORDER BY
*/

public class OrdinalRewriterRuleTest {

abbashus marked this conversation as resolved.
Show resolved Hide resolved
// Queries that use an ordinal in GROUP BY and/or ORDER BY must be picked up by the rule.

@Test
public void ordinalInGroupByShouldMatch() {
    QueryAssertion assertion = query("SELECT lastname FROM bank GROUP BY 1");
    assertion.shouldMatchRule();
}

@Test
public void ordinalInOrderByShouldMatch() {
    QueryAssertion assertion = query("SELECT lastname FROM bank ORDER BY 1");
    assertion.shouldMatchRule();
}

@Test
public void ordinalInGroupAndOrderByShouldMatch() {
    QueryAssertion assertion = query("SELECT lastname, age FROM bank GROUP BY 2, 1 ORDER BY 1");
    assertion.shouldMatchRule();
}

// Queries that only reference fields by name must be left alone by the rule.

@Test
public void noOrdinalInGroupByShouldNotMatch() {
    QueryAssertion assertion = query("SELECT lastname FROM bank GROUP BY lastname");
    assertion.shouldNotMatchRule();
}

@Test
public void noOrdinalInOrderByShouldNotMatch() {
    QueryAssertion assertion = query("SELECT lastname, age FROM bank ORDER BY age");
    assertion.shouldNotMatchRule();
}

@Test
public void noOrdinalInGroupAndOrderByShouldNotMatch() {
    QueryAssertion assertion = query("SELECT lastname, age FROM bank GROUP BY lastname, age ORDER BY age");
    assertion.shouldNotMatchRule();
}

// GROUP BY ordinals are replaced by the select item at that (1-based) position.

@Test
public void simpleGroupByOrdinal() {
    query("SELECT lastname FROM bank GROUP BY 1"
    ).shouldBeAfterRewrite("SELECT lastname FROM bank GROUP BY lastname");
}

@Test
public void multipleGroupByOrdinal() {
    query("SELECT lastname, age FROM bank GROUP BY 1, 2 "
    ).shouldBeAfterRewrite("SELECT lastname, age FROM bank GROUP BY lastname, age");

    query("SELECT lastname, age FROM bank GROUP BY 2, 1"
    ).shouldBeAfterRewrite("SELECT lastname, age FROM bank GROUP BY age, lastname");

    // ordinals mixed with a field that is part of the select list
    query("SELECT lastname, age, firstname FROM bank GROUP BY 2, firstname, 1"
    ).shouldBeAfterRewrite("SELECT lastname, age, firstname FROM bank GROUP BY age, firstname, lastname");

    // ordinals mixed with a field that is not part of the select list
    query("SELECT lastname, age, firstname FROM bank GROUP BY 2, something, 1"
    ).shouldBeAfterRewrite("SELECT lastname, age, firstname FROM bank GROUP BY age, something, lastname");
}

// An ordinal beyond the end of the select list must be rejected with a VerificationException.
@Test
public void groupByOrdinalOutOfRangeShouldThrowException() {
    query("SELECT lastname FROM bank GROUP BY 2").shouldThrowException(2);
}


// ORDER BY ordinals are replaced by the select item at that (1-based) position.

@Test
public void simpleOrderByOrdinal() {
    query("SELECT lastname FROM bank ORDER BY 1"
    ).shouldBeAfterRewrite("SELECT lastname FROM bank ORDER BY lastname");
}

@Test
public void multipleOrderByOrdinal() {
    query("SELECT lastname, age FROM bank ORDER BY 1, 2 "
    ).shouldBeAfterRewrite("SELECT lastname, age FROM bank ORDER BY lastname, age");

    query("SELECT lastname, age FROM bank ORDER BY 2, 1"
    ).shouldBeAfterRewrite("SELECT lastname, age FROM bank ORDER BY age, lastname");

    // ordinals mixed with a field that is part of the select list
    query("SELECT lastname, age, firstname FROM bank ORDER BY 2, firstname, 1"
    ).shouldBeAfterRewrite("SELECT lastname, age, firstname FROM bank ORDER BY age, firstname, lastname");

    // ordinals mixed with a field that is not part of the select list
    query("SELECT lastname, age, firstname FROM bank ORDER BY 2, department, 1"
    ).shouldBeAfterRewrite("SELECT lastname, age, firstname FROM bank ORDER BY age, department, lastname");
}

// An ordinal beyond the end of the select list must be rejected with a VerificationException.
@Test
public void orderByOrdinalOutOfRangeShouldThrowException() {
    query("SELECT lastname FROM bank ORDER BY 2").shouldThrowException(2);
}


// TODO: Some more Tests

/**
 * Entry point of the fluent test helper.
 *
 * @param sql query text under test
 * @return assertion object wrapping the parsed query and a rule built from the same text
 */
private QueryAssertion query(String sql) {
    QueryAssertion assertion = new QueryAssertion(sql);
    return assertion;
}
/**
 * Fluent helper that pairs a parsed query with an OrdinalRewriterRule created from the same
 * SQL text and offers assertions on the rule's match / rewrite outcome.
 */
private static class QueryAssertion {

    private final OrdinalRewriterRule rule;
    private final SQLQueryExpr expr;

    QueryAssertion(String sql) {
        this.expr = SqlParserUtils.parse(sql);
        this.rule = new OrdinalRewriterRule(sql);
    }

    /**
     * Asserts that the rule matches and that, after the rewrite, the query prints the same
     * MySQL string as {@code expected} does once parsed.
     */
    void shouldBeAfterRewrite(String expected) {
        shouldMatchRule();
        rule.rewrite(expr);

        String expectedSql = SQLUtils.toMySqlString(SqlParserUtils.parse(expected));
        String actualSql = SQLUtils.toMySqlString(expr);
        Assert.assertEquals(expectedSql, actualSql);
    }

    /** Asserts that the rule accepts the query. */
    void shouldMatchRule() {
        Assert.assertTrue(match());
    }

    /** Asserts that the rule ignores the query. */
    void shouldNotMatchRule() {
        Assert.assertFalse(match());
    }

    /**
     * Asserts that rewriting fails with a VerificationException whose message names the
     * given ordinal.
     */
    void shouldThrowException(int ordinal) {
        try {
            shouldMatchRule();
            rule.rewrite(expr);
            Assert.fail("Expected VerificationException, but none was thrown");
        } catch (VerificationException rejected) {
            String expectedFragment = "Invalid ordinal ["+ ordinal +"] specified in";
            Assert.assertThat(rejected.getMessage(), containsString(expectedFragment));
        }
    }

    private boolean match() {
        return rule.match(expr);
    }
}
}