Permalink
Browse files

IMPALA-2081: Add PERCENT_RANK, NTILE, CUME_DIST analytic window functions

These functions are implemented as rewrites in the analysis stage. They are rewritten as
different arithmetic expressions and make use of the existing analytic functions such as
'rank', 'count' and 'row_number' to compute the final results.

TODO: IMPALA-2171: NTILE() currently takes only constant expressions. We need to modify
it to take non-constant expressions as well in a future patch.

Change-Id: I8773df8ceefff27ab66a41169dc4ac0927465191
Reviewed-on: http://gerrit.cloudera.org:8080/584
Tested-by: Internal Jenkins
Reviewed-by: Henry Robinson <henry@cloudera.com>
  • Loading branch information...
smukil authored and henryr committed Jul 28, 2015
1 parent 6837cde commit 1e65d1a642f935eddb38d7362e04404f966cb1c8
@@ -27,6 +27,7 @@
import com.cloudera.impala.catalog.AggregateFunction;
import com.cloudera.impala.catalog.Function;
import com.cloudera.impala.catalog.Type;
+import com.cloudera.impala.catalog.ScalarType;
import com.cloudera.impala.common.AnalysisException;
import com.cloudera.impala.common.InternalException;
import com.cloudera.impala.common.TreeNode;
@@ -84,6 +85,9 @@
private static String ROWNUMBER = "row_number";
private static String MIN = "min";
private static String MAX = "max";
+ private static String PERCENT_RANK = "percent_rank";
+ private static String CUME_DIST = "cume_dist";
+ private static String NTILE = "ntile";
// Internal function used to implement FIRST_VALUE with a window rewrite and
// additional null handling in the backend.
@@ -181,31 +185,176 @@ public String debugString() {
// Empty body: this implementation writes nothing into 'msg'.
protected void toThrift(TExprNode msg) {
}
// Returns true if 'fn' is an AggregateFunction that reports itself as an analytic
// function through isAnalyticFn(). This change narrows the method from public to
// private.
- public static boolean isAnalyticFn(Function fn) {
+ private static boolean isAnalyticFn(Function fn) {
return fn instanceof AggregateFunction
&& ((AggregateFunction) fn).isAnalyticFn();
}
+ /**
+  * Returns true if 'fn' is an analytic function whose function name equals 'fnName'.
+  * The name is only inspected once 'fn' is known to be an analytic function.
+  */
+ private static boolean isAnalyticFn(Function fn, String fnName) {
+   if (!isAnalyticFn(fn)) return false;
+   return fn.functionName().equals(fnName);
+ }
+
/**
 * Returns true if 'fn' is an AggregateFunction that reports itself as an aggregate
 * function through isAggregateFn(); false for any other kind of function.
 */
public static boolean isAggregateFn(Function fn) {
  if (!(fn instanceof AggregateFunction)) return false;
  return ((AggregateFunction) fn).isAggregateFn();
}
/** Returns true if 'fn' is an analytic function named PERCENT_RANK ("percent_rank"). */
+ public static boolean isPercentRankFn(Function fn) {
+ return isAnalyticFn(fn, PERCENT_RANK);
+ }
+
/** Returns true if 'fn' is an analytic function named CUME_DIST ("cume_dist"). */
+ public static boolean isCumeDistFn(Function fn) {
+ return isAnalyticFn(fn, CUME_DIST);
+ }
+
/** Returns true if 'fn' is an analytic function named NTILE ("ntile"). */
+ public static boolean isNtileFn(Function fn) {
+ return isAnalyticFn(fn, NTILE);
+ }
+
// Returns true if 'fn' is an analytic function whose name is LEAD or LAG. The added
// line expresses the same check as the removed lines through the two-argument
// isAnalyticFn() helper.
static private boolean isOffsetFn(Function fn) {
- if (!isAnalyticFn(fn)) return false;
- return fn.functionName().equals(LEAD) || fn.functionName().equals(LAG);
+ return isAnalyticFn(fn, LEAD) || isAnalyticFn(fn, LAG);
}
// Returns true if 'fn' is an analytic function whose name is MIN ("min") or
// MAX ("max"). The added line expresses the same check as the removed lines through
// the two-argument isAnalyticFn() helper.
static private boolean isMinMax(Function fn) {
- if (!isAnalyticFn(fn)) return false;
- return fn.functionName().equals(MIN) || fn.functionName().equals(MAX);
+ return isAnalyticFn(fn, MIN) || isAnalyticFn(fn, MAX);
}
// Returns true if 'fn' is an analytic function whose name is RANK, DENSERANK or
// ROWNUMBER ("row_number"). The added lines express the same check as the removed
// lines through the two-argument isAnalyticFn() helper.
static private boolean isRankingFn(Function fn) {
- if (!isAnalyticFn(fn)) return false;
- return fn.functionName().equals(RANK)
- || fn.functionName().equals(DENSERANK)
- || fn.functionName().equals(ROWNUMBER);
+ return isAnalyticFn(fn, RANK) || isAnalyticFn(fn, DENSERANK) ||
+ isAnalyticFn(fn, ROWNUMBER);
+ }
+
+ /**
+  * Entry point for the rewrite of the analytic functions percent_rank(), cume_dist()
+  * and ntile().
+  *
+  * Returns the replacement Expr when 'analyticExpr' calls one of those three
+  * functions, and null for every other function, which is left as it is.
+  */
+ public static Expr rewrite(AnalyticExpr analyticExpr) {
+   Function calledFn = analyticExpr.getFnCall().getFn();
+   if (AnalyticExpr.isPercentRankFn(calledFn)) return createPercentRank(analyticExpr);
+   if (AnalyticExpr.isCumeDistFn(calledFn)) return createCumeDist(analyticExpr);
+   if (AnalyticExpr.isNtileFn(calledFn)) return createNtile(analyticExpr);
+   // Not a function that gets rewritten.
+   return null;
+ }
+
+ /**
+  * Rewrite percent_rank() to the following:
+  *
+  * percent_rank() over([partition by clause] order by clause)
+  *    = if(Count = 1, 0, (Rank - 1)/(Count - 1))
+  * where,
+  *  Rank = rank() over([partition by clause] order by clause)
+  *  Count = count() over([partition by clause])
+  *
+  * The if() guards partitions that hold a single row: there Rank = Count = 1, so the
+  * plain quotient would be 0/0. The percent rank of the only row in its partition
+  * is 0.
+  *
+  * Returns the Expr that replaces 'analyticExpr', which must be a percent_rank() call.
+  */
+ private static Expr createPercentRank(AnalyticExpr analyticExpr) {
+   Preconditions.checkState(
+       AnalyticExpr.isPercentRankFn(analyticExpr.getFnCall().getFn()));
+   AnalyticExpr rankExpr =
+       create("rank", analyticExpr, true, false);
+   AnalyticExpr countExpr =
+       create("count", analyticExpr, false, false);
+   NumericLiteral zero = new NumericLiteral(BigInteger.valueOf(0), ScalarType.BIGINT);
+   NumericLiteral one = new NumericLiteral(BigInteger.valueOf(1), ScalarType.BIGINT);
+   ArithmeticExpr arithmeticRewrite =
+       new ArithmeticExpr(ArithmeticExpr.Operator.DIVIDE,
+           new ArithmeticExpr(ArithmeticExpr.Operator.SUBTRACT, rankExpr, one),
+           new ArithmeticExpr(ArithmeticExpr.Operator.SUBTRACT, countExpr, one));
+
+   // if(Count = 1, 0, (Rank - 1)/(Count - 1)): avoid dividing by Count - 1 = 0.
+   List<Expr> ifParams = Lists.newArrayList();
+   ifParams.add(
+       new BinaryPredicate(BinaryPredicate.Operator.EQ, countExpr, one));
+   ifParams.add(zero);
+   ifParams.add(arithmeticRewrite);
+   return new FunctionCallExpr("if", ifParams);
+ }
+
+ /**
+  * Rewrite cume_dist() to the following:
+  *
+  * cume_dist() over([partition by clause] order by clause)
+  *    = ((Count - Rank) + 1)/Count
+  * where,
+  *  Rank = rank() over([partition by clause] order by clause DESC)
+  *  Count = count() over([partition by clause])
+  *
+  * Rank is computed over the reversed ordering of 'analyticExpr', and Count carries
+  * no order by at all.
+  */
+ private static Expr createCumeDist(AnalyticExpr analyticExpr) {
+   Preconditions.checkState(
+       AnalyticExpr.isCumeDistFn(analyticExpr.getFnCall().getFn()));
+   AnalyticExpr reversedRank = create("rank", analyticExpr, true, true);
+   AnalyticExpr partitionCount = create("count", analyticExpr, false, false);
+   NumericLiteral one = new NumericLiteral(BigInteger.valueOf(1), ScalarType.BIGINT);
+
+   // Build the formula inside out: (Count - Rank), then + 1, then / Count.
+   ArithmeticExpr countMinusRank = new ArithmeticExpr(
+       ArithmeticExpr.Operator.SUBTRACT, partitionCount, reversedRank);
+   ArithmeticExpr numerator = new ArithmeticExpr(
+       ArithmeticExpr.Operator.ADD, countMinusRank, one);
+   return new ArithmeticExpr(
+       ArithmeticExpr.Operator.DIVIDE, numerator, partitionCount);
+ }
+
+ /**
+  * Rewrite ntile() to the following:
+  *
+  * ntile(B) over([partition by clause] order by clause)
+  *    = floor(min(Count, B) * (RowNumber - 1)/Count) + 1
+  * where,
+  *  RowNumber = row_number() over([partition by clause] order by clause)
+  *  Count = count() over([partition by clause])
+  *
+  * min(Count, B) is spelled if(B < Count, B, Count), and the floor of the quotient is
+  * obtained with the INT_DIVIDE operator. B is the first child of 'analyticExpr'.
+  */
+ private static Expr createNtile(AnalyticExpr analyticExpr) {
+   Preconditions.checkState(
+       AnalyticExpr.isNtileFn(analyticExpr.getFnCall().getFn()));
+   Expr numBuckets = analyticExpr.getChild(0);
+   AnalyticExpr rowNumber = create("row_number", analyticExpr, true, false);
+   AnalyticExpr partitionCount = create("count", analyticExpr, false, false);
+
+   // Arguments of if(B < Count, B, Count).
+   List<Expr> minArgs = Lists.<Expr>newArrayList(
+       new BinaryPredicate(BinaryPredicate.Operator.LT, numBuckets, partitionCount),
+       numBuckets,
+       partitionCount);
+
+   NumericLiteral one = new NumericLiteral(BigInteger.valueOf(1), ScalarType.BIGINT);
+   ArithmeticExpr rowNumberMinusOne = new ArithmeticExpr(
+       ArithmeticExpr.Operator.SUBTRACT, rowNumber, one);
+   FunctionCallExpr minOfCountAndBuckets = new FunctionCallExpr("if", minArgs);
+   ArithmeticExpr product = new ArithmeticExpr(
+       ArithmeticExpr.Operator.MULTIPLY, rowNumberMinusOne, minOfCountAndBuckets);
+   ArithmeticExpr quotient = new ArithmeticExpr(
+       ArithmeticExpr.Operator.INT_DIVIDE, product, partitionCount);
+   return new ArithmeticExpr(ArithmeticExpr.Operator.ADD, quotient, one);
+ }
+
+ /**
+  * Builds a new AnalyticExpr around a call to the function 'fnName' that takes no
+  * arguments and is flagged as an analytic function call.
+  * The partition exprs are cloned from 'referenceExpr'. Its order by elements are
+  * carried over only if 'copyOrderBy' is set: reversed via OrderByElement.reverse()
+  * when 'reverseOrderBy' is set as well, cloned one by one otherwise. Without
+  * 'copyOrderBy' the new expr gets null order by elements and 'reverseOrderBy' is
+  * ignored.
+  */
+ private static AnalyticExpr create(String fnName,
+     AnalyticExpr referenceExpr, boolean copyOrderBy, boolean reverseOrderBy) {
+   FunctionCallExpr fnCall = new FunctionCallExpr(fnName, new ArrayList<Expr>());
+   fnCall.setIsAnalyticFnCall(true);
+
+   List<OrderByElement> orderBy = null;
+   if (copyOrderBy && reverseOrderBy) {
+     orderBy = OrderByElement.reverse(referenceExpr.getOrderByElements());
+   } else if (copyOrderBy) {
+     orderBy = Lists.newArrayList();
+     for (OrderByElement refElem: referenceExpr.getOrderByElements()) {
+       orderBy.add(refElem.clone());
+     }
+   }
+
+   return new AnalyticExpr(fnCall,
+       Expr.cloneList(referenceExpr.getPartitionExprs()), orderBy, null);
}
/**
@@ -320,6 +469,23 @@ public void analyze(Analyzer analyzer) throws AnalysisException {
}
}
}
+ if (isNtileFn(fn)) {
+ // TODO: IMPALA-2171:Remove this when ntile() can handle a non-constant argument.
+ if (!getFnCall().getChild(0).isConstant()) {
+ throw new AnalysisException("NTILE() requires a constant argument");
+ }
+ // Check if argument value is zero or negative and throw an exception if found.
+ try {
+ TColumnValue bucketValue =
+ FeSupport.EvalConstExpr(getFnCall().getChild(0), analyzer.getQueryCtx());
+ Long arg = bucketValue.getLong_val();
+ if (arg <= 0) {
+ throw new AnalysisException("NTILE() requires a positive argument: " + arg);
+ }
+ } catch (InternalException e) {
+ throw new AnalysisException(e.toString());
+ }
+ }
}
if (window_ != null) {
@@ -257,7 +257,7 @@ public void analyze(Analyzer analyzer) throws AnalysisException {
createSortInfo(analyzer);
analyzeAggregation(analyzer);
- analyzeAnalytics(analyzer);
+ createAnalyticInfo(analyzer);
if (evaluateOrderBy_) createSortTupleInfo(analyzer);
// Remember the SQL string before inline-view expression substitution.
@@ -768,7 +768,7 @@ private void createAggInfo(ArrayList<Expr> groupingExprs,
* If the select list contains AnalyticExprs, create AnalyticInfo and substitute
* AnalyticExprs using the AnalyticInfo's smap.
*/
- private void analyzeAnalytics(Analyzer analyzer)
+ private void createAnalyticInfo(Analyzer analyzer)
throws AnalysisException {
// collect AnalyticExprs from the SELECT and ORDER BY clauses
ArrayList<Expr> analyticExprs = Lists.newArrayList();
@@ -778,15 +778,42 @@ private void analyzeAnalytics(Analyzer analyzer)
analyticExprs);
}
if (analyticExprs.isEmpty()) return;
+ ExprSubstitutionMap rewriteSmap = new ExprSubstitutionMap();
+ for (Expr expr: analyticExprs) {
+ AnalyticExpr toRewrite = (AnalyticExpr)expr;
+ Expr newExpr = AnalyticExpr.rewrite(toRewrite);
+ if (newExpr != null) {
+ newExpr.analyze(analyzer);
+ if (!rewriteSmap.containsMappingFor(toRewrite)) {
+ rewriteSmap.put(toRewrite, newExpr);
+ }
+ }
+ }
+ if (rewriteSmap.size() > 0) {
+ // Substitute the exprs with their rewritten versions.
+ ArrayList<Expr> updatedAnalyticExprs =
+ Expr.substituteList(analyticExprs, rewriteSmap, analyzer, false);
+ // Get rid of the original exprs, which have been replaced by their rewrites.
+ analyticExprs.clear();
+ // Collect the new exprs introduced through the rewrite and the non-rewrite exprs.
+ TreeNode.collect(updatedAnalyticExprs, AnalyticExpr.class, analyticExprs);
+ }
+
analyticInfo_ = AnalyticInfo.create(analyticExprs, analyzer);
+ ExprSubstitutionMap smap = analyticInfo_.getSmap();
+ // If any exprs were rewritten, compose the rewrite smap with the analytic smap.
+ if (rewriteSmap.size() > 0) {
+ smap = ExprSubstitutionMap.compose(
+ rewriteSmap, analyticInfo_.getSmap(), analyzer);
+ }
// change select list and ordering exprs to point to analytic output. We need
// to reanalyze the exprs at this point.
- resultExprs_ = Expr.substituteList(resultExprs_, analyticInfo_.getSmap(), analyzer,
+ resultExprs_ = Expr.substituteList(resultExprs_, smap, analyzer,
false);
LOG.trace("post-analytic selectListExprs: " + Expr.debugString(resultExprs_));
if (sortInfo_ != null) {
- sortInfo_.substituteOrderingExprs(analyticInfo_.getSmap(), analyzer);
+ sortInfo_.substituteOrderingExprs(smap, analyzer);
LOG.trace("post-analytic orderingExprs: " +
Expr.debugString(sortInfo_.getOrderingExprs()));
}
@@ -815,6 +815,14 @@ private void initAggregateBuiltins() {
prefix + "10CountMergeEPN10impala_udf15FunctionContextERKNS1_9BigIntValEPS4_",
null, null));
+ // The following 3 functions are never directly executed because they get rewritten
+ db.addBuiltin(AggregateFunction.createAnalyticBuiltin(
+ db, "percent_rank", Lists.<Type>newArrayList(), Type.DOUBLE, Type.STRING));
+ db.addBuiltin(AggregateFunction.createAnalyticBuiltin(
+ db, "cume_dist", Lists.<Type>newArrayList(), Type.DOUBLE, Type.STRING));
+ db.addBuiltin(AggregateFunction.createAnalyticBuiltin(
+ db, "ntile", Lists.<Type>newArrayList(Type.BIGINT), Type.BIGINT, Type.STRING));
+
for (Type t: Type.getSupportedTypes()) {
if (t.isNull()) continue; // NULL is handled through type promotion.
if (t.isScalarType(PrimitiveType.CHAR)) continue; // promoted to STRING
@@ -962,6 +962,34 @@ public void TestAnalyticExprs() throws AnalysisException {
// + "from functional.alltypes",
// "Only one ORDER BY expression allowed if used with a RANGE window with "
// + "PRECEDING/FOLLOWING");
+
+ // percent_rank(), cume_dist() and ntile() tests
+ AnalyzesOk("select percent_rank() over(order by id) from functional.alltypes");
+ AnalyzesOk("select cume_dist() over(order by id) from functional.alltypes");
+ AnalyzesOk("select ntile(3) over(order by id) from functional.alltypes");
+ AnalyzesOk("select ntile(3000) over(order by id) from functional.alltypes");
+ AnalyzesOk("select ntile(3000000000) over(order by id) from functional.alltypes");
+ AnalyzesOk("select percent_rank() over(partition by tinyint_col, bool_col "
+ + "order by id), ntile(3) over(partition by int_col, bool_col "
+ + "order by smallint_col, id), cume_dist() over(partition by int_col, bool_col "
+ + "order by month) from functional.alltypes");
+
+ AnalysisError("select ntile(-1) over(order by int_col) from functional.alltypestiny",
+ "NTILE() requires a positive argument: -1");
+ AnalysisError("select percent_rank() over(partition by int_col) "
+ + "from functional.alltypestiny",
+ "'percent_rank()' requires an ORDER BY clause");
+ AnalysisError("select cume_dist() over(partition by int_col) "
+ + "from functional.alltypestiny",
+ "'cume_dist()' requires an ORDER BY clause");
+ AnalysisError("select ntile(2) over(partition by int_col) "
+ + "from functional.alltypestiny",
+ "'ntile(2)' requires an ORDER BY clause");
+ // TODO: Remove this test once we allow for non-constant arguments in ntile()
+ AnalysisError(
+ "select ntile(int_col) over(order by tinyint_col) "
+ + "from functional.alltypestiny",
+ "NTILE() requires a constant argument");
}
/**
Oops, something went wrong.

0 comments on commit 1e65d1a

Please sign in to comment.