Skip to content

Commit

Permalink
SQL: Fix wrong appliance of StackOverflow limit for IN
Browse files Browse the repository at this point in the history
Fix grammar so that each element inside the list of values for IN
is a `valueExpression` and not a more generic `expression`. Also change
some names in the grammar so that they match the primary rule name
plus "Default". This helps so that the decrement of depth counts in
the Parser's `CircuitBreakerListener` works correctly.

For the list of values for `IN` don't count the
`PrimaryExpressionContext` as this is not visited on exit due to
the peculiarity in our gramamr with the `predicate` and `predicated`.

Fixes: #36592
  • Loading branch information
matriv committed Dec 18, 2018
1 parent 4103d3b commit 73a1e1c
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 16 deletions.
2 changes: 1 addition & 1 deletion x-pack/plugin/sql/src/main/antlr/SqlBase.g4
Expand Up @@ -186,7 +186,7 @@ predicated
// instead the property kind is used to differentiate
predicate
: NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression
| NOT? kind=IN '(' expression (',' expression)* ')'
| NOT? kind=IN '(' valueExpression (',' valueExpression)* ')'
| NOT? kind=IN '(' query ')'
| NOT? kind=LIKE pattern
| NOT? kind=RLIKE regex=string
Expand Down
Expand Up @@ -226,7 +226,7 @@ public Expression visitPredicated(PredicatedContext ctx) {
if (pCtx.query() != null) {
throw new ParsingException(loc, "IN query not supported yet");
}
e = new In(loc, exp, expressions(pCtx.expression()));
e = new In(loc, exp, expressions(pCtx.valueExpression()));
break;
case SqlBaseParser.LIKE:
e = new Like(loc, exp, visitPattern(pCtx.pattern()));
Expand Down
Expand Up @@ -3363,12 +3363,6 @@ public ValueExpressionContext valueExpression(int i) {
return getRuleContext(ValueExpressionContext.class,i);
}
public TerminalNode NOT() { return getToken(SqlBaseParser.NOT, 0); }
public List<ExpressionContext> expression() {
return getRuleContexts(ExpressionContext.class);
}
public ExpressionContext expression(int i) {
return getRuleContext(ExpressionContext.class,i);
}
public TerminalNode IN() { return getToken(SqlBaseParser.IN, 0); }
public QueryContext query() {
return getRuleContext(QueryContext.class,0);
Expand Down Expand Up @@ -3449,7 +3443,7 @@ public final PredicateContext predicate() throws RecognitionException {
setState(502);
match(T__0);
setState(503);
expression();
valueExpression(0);
setState(508);
_errHandler.sync(this);
_la = _input.LA(1);
Expand All @@ -3459,7 +3453,7 @@ public final PredicateContext predicate() throws RecognitionException {
setState(504);
match(T__2);
setState(505);
expression();
valueExpression(0);
}
}
setState(510);
Expand Down Expand Up @@ -6616,7 +6610,7 @@ private boolean valueExpression_sempred(ValueExpressionContext _localctx, int pr
"\u01f0\7\16\2\2\u01f0\u01f1\5<\37\2\u01f1\u01f2\7\n\2\2\u01f2\u01f3\5"+
"<\37\2\u01f3\u021b\3\2\2\2\u01f4\u01f6\7=\2\2\u01f5\u01f4\3\2\2\2\u01f5"+
"\u01f6\3\2\2\2\u01f6\u01f7\3\2\2\2\u01f7\u01f8\7-\2\2\u01f8\u01f9\7\3"+
"\2\2\u01f9\u01fe\5,\27\2\u01fa\u01fb\7\5\2\2\u01fb\u01fd\5,\27\2\u01fc"+
"\2\2\u01f9\u01fe\5<\37\2\u01fa\u01fb\7\5\2\2\u01fb\u01fd\5<\37\2\u01fc"+
"\u01fa\3\2\2\2\u01fd\u0200\3\2\2\2\u01fe\u01fc\3\2\2\2\u01fe\u01ff\3\2"+
"\2\2\u01ff\u0201\3\2\2\2\u0200\u01fe\3\2\2\2\u0201\u0202\7\4\2\2\u0202"+
"\u021b\3\2\2\2\u0203\u0205\7=\2\2\u0204\u0203\3\2\2\2\u0204\u0205\3\2"+
Expand Down
Expand Up @@ -26,6 +26,14 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryPrimaryDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryTermContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionDefaultContext;
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;

Expand Down Expand Up @@ -214,10 +222,26 @@ public void exitNonReserved(SqlBaseParser.NonReservedContext context) {
/**
* Used to catch large expressions that can lead to stack overflows
*/
private class CircuitBreakerListener extends SqlBaseBaseListener {
static class CircuitBreakerListener extends SqlBaseBaseListener {

private static final short MAX_RULE_DEPTH = 200;

/**
* Due to the structure of the grammar and our custom handling in {@link ExpressionBuilder}
* some expressions can exit with a different class than they entered:
* e.g.: ValueExpressionContext can exit as ValueExpressionDefaultContext
*/
private static final Map<String, String> ENTER_EXIT_RULE_MAPPING = new HashMap<>();

static {
ENTER_EXIT_RULE_MAPPING.put(StatementDefaultContext.class.getSimpleName(), StatementContext.class.getSimpleName());
ENTER_EXIT_RULE_MAPPING.put(QueryPrimaryDefaultContext.class.getSimpleName(), QueryTermContext.class.getSimpleName());
ENTER_EXIT_RULE_MAPPING.put(BooleanDefaultContext.class.getSimpleName(), BooleanExpressionContext.class.getSimpleName());
ENTER_EXIT_RULE_MAPPING.put(ValueExpressionDefaultContext.class.getSimpleName(), ValueExpressionContext.class.getSimpleName());
}

private boolean insideIn = false;

// Keep current depth for every rule visited.
// The totalDepth alone cannot be used as expressions like: e1 OR e2 OR e3 OR ...
// are processed as e1 OR (e2 OR (e3 OR (... and this results in the totalDepth not growing
Expand All @@ -226,9 +250,18 @@ private class CircuitBreakerListener extends SqlBaseBaseListener {

@Override
public void enterEveryRule(ParserRuleContext ctx) {
if (inDetected(ctx)) {
insideIn = true;
}

// Skip PrimaryExpressionContext for IN as it's not visited on exit due to
// the grammar's peculiarity rule with "predicated" and "predicate".
// Also skip the Identifiers as they are "cheap".
if (ctx.getClass() != SqlBaseParser.UnquoteIdentifierContext.class &&
ctx.getClass() != SqlBaseParser.QuoteIdentifierContext.class &&
ctx.getClass() != SqlBaseParser.BackQuotedIdentifierContext.class) {
ctx.getClass() != SqlBaseParser.BackQuotedIdentifierContext.class &&
(insideIn == false || ctx.getClass() != SqlBaseParser.PrimaryExpressionContext.class)) {

int currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 1, (short) 1);
if (currentDepth > MAX_RULE_DEPTH) {
throw new ParsingException(source(ctx), "SQL statement too large; " +
Expand All @@ -240,12 +273,35 @@ public void enterEveryRule(ParserRuleContext ctx) {

@Override
public void exitEveryRule(ParserRuleContext ctx) {
// Avoid having negative numbers
if (depthCounts.containsKey(ctx.getClass().getSimpleName())) {
depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 0, (short) -1);
if (inDetected(ctx)) {
insideIn = false;
}

decrementCounter(ctx);
super.exitEveryRule(ctx);
}

ObjectShortHashMap<String> depthCounts() {
return depthCounts;
}

private void decrementCounter(ParserRuleContext ctx) {
String className = ctx.getClass().getSimpleName();
String classNameToDecrement = ENTER_EXIT_RULE_MAPPING.getOrDefault(className, className);

// Avoid having negative numbers
if (depthCounts.containsKey(classNameToDecrement)) {
depthCounts.putOrAdd(classNameToDecrement, (short) 0, (short) -1);
}
}

private boolean inDetected(ParserRuleContext ctx) {
if (ctx.getParent() != null && ctx.getParent().getClass() == SqlBaseParser.PredicateContext.class) {
SqlBaseParser.PredicateContext pc = (SqlBaseParser.PredicateContext) ctx.getParent();
return pc.kind != null && pc.kind.getType() == SqlBaseParser.IN;
}
return false;
}
}

private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
Expand Down
Expand Up @@ -15,6 +15,13 @@
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate;
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MultiMatchQueryPredicate;
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.StringQueryPredicate;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryPrimaryDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryTermContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionDefaultContext;
import org.elasticsearch.xpack.sql.plan.logical.Filter;
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.sql.plan.logical.OrderBy;
Expand Down Expand Up @@ -254,6 +261,40 @@ public void testLimitToPreventStackOverflowFromLargeComplexSubselectTree() {
e.getMessage());
}

public void testLimitStackOverflowForInAndLiteralsIsNotApplied() {
new SqlParser().createStatement("SELECT * FROM t WHERE a IN(" +
Joiner.on(",").join(nCopies(100_000, "a + b")) + ")");
}

public void testDecrementOfDepthCounter() {
SqlParser.CircuitBreakerListener cbl = new SqlParser.CircuitBreakerListener();
StatementContext sc = new StatementContext();
QueryTermContext qtc = new QueryTermContext();
ValueExpressionContext vec = new ValueExpressionContext();
BooleanExpressionContext bec = new BooleanExpressionContext();

cbl.enterEveryRule(sc);
cbl.enterEveryRule(sc);
cbl.enterEveryRule(qtc);
cbl.enterEveryRule(qtc);
cbl.enterEveryRule(qtc);
cbl.enterEveryRule(vec);
cbl.enterEveryRule(bec);
cbl.enterEveryRule(bec);

cbl.exitEveryRule(new StatementDefaultContext(sc));
cbl.exitEveryRule(new StatementDefaultContext(sc));
cbl.exitEveryRule(new QueryPrimaryDefaultContext(qtc));
cbl.exitEveryRule(new QueryPrimaryDefaultContext(qtc));
cbl.exitEveryRule(new ValueExpressionDefaultContext(vec));
cbl.exitEveryRule(new SqlBaseParser.BooleanDefaultContext(bec));

assertEquals(0, cbl.depthCounts().get(SqlBaseParser.StatementContext.class.getSimpleName()));
assertEquals(1, cbl.depthCounts().get(SqlBaseParser.QueryTermContext.class.getSimpleName()));
assertEquals(0, cbl.depthCounts().get(SqlBaseParser.ValueExpressionContext.class.getSimpleName()));
assertEquals(1, cbl.depthCounts().get(SqlBaseParser.BooleanExpressionContext.class.getSimpleName()));
}

private LogicalPlan parseStatement(String sql) {
return new SqlParser().createStatement(sql);
}
Expand Down

0 comments on commit 73a1e1c

Please sign in to comment.