Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.ViewExpanders;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
Expand Down Expand Up @@ -2769,19 +2770,44 @@ public RelNode visitFillNull(FillNull node, CalcitePlanContext context) {
return context.relBuilder.peek();
}

/** Window {@code ORDER BY} keys from the current node's collation, or empty if it has none. */
private static List<RexNode> deriveCollationOrderKeys(CalcitePlanContext context) {
RelBuilder relBuilder = context.relBuilder;
List<RelCollation> collations =
relBuilder.getCluster().getMetadataQuery().collations(relBuilder.peek());
if (collations == null || collations.isEmpty()) {
return List.of();
}
List<RexNode> orderKeys = new ArrayList<>();
for (RelFieldCollation fieldCollation : collations.get(0).getFieldCollations()) {
RexNode key = relBuilder.field(fieldCollation.getFieldIndex());
if (fieldCollation.direction.isDescending()) {
key = relBuilder.desc(key);
}
if (fieldCollation.nullDirection == RelFieldCollation.NullDirection.LAST) {
key = relBuilder.nullsLast(key);
} else if (fieldCollation.nullDirection == RelFieldCollation.NullDirection.FIRST) {
key = relBuilder.nullsFirst(key);
}
orderKeys.add(key);
}
return orderKeys;
}

@Override
public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) {
// 1. resolve main plan
visitChildren(node, context);
// 2. add row_number() column to main
// 2. add row_number() column to main, ordered by its collation so the zip is deterministic
List<RexNode> mainOrderKeys = deriveCollationOrderKeys(context);
RexNode mainRowNumber =
PlanUtils.makeOver(
context,
BuiltinFunctionName.ROW_NUMBER,
null,
List.of(),
List.of(),
List.of(),
mainOrderKeys,
WindowFrame.toCurrentRow());
context.relBuilder.projectPlus(
context.relBuilder.alias(mainRowNumber, ROW_NUMBER_COLUMN_FOR_MAIN));
Expand All @@ -2791,15 +2817,16 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) {
transformPlanToAttachChild(node.getSubSearch(), relation);
// 4. resolve subsearch plan
node.getSubSearch().accept(this, context);
// 5. add row_number() column to subsearch
// 5. add row_number() column to subsearch, ordered by its collation
List<RexNode> subsearchOrderKeys = deriveCollationOrderKeys(context);
RexNode subsearchRowNumber =
PlanUtils.makeOver(
context,
BuiltinFunctionName.ROW_NUMBER,
null,
List.of(),
List.of(),
List.of(),
subsearchOrderKeys,
WindowFrame.toCurrentRow());
context.relBuilder.projectPlus(
context.relBuilder.alias(subsearchRowNumber, ROW_NUMBER_COLUMN_FOR_SUBSEARCH));
Expand All @@ -2821,6 +2848,11 @@ public RelNode visitAppendCol(AppendCol node, CalcitePlanContext context) {
context.relBuilder.join(
JoinAndLookupUtils.translateJoinType(Join.JoinType.FULL), joinCondition);

// sort by the row numbers (nulls last) so the output order is stable across backends
context.relBuilder.sort(
context.relBuilder.nullsLast(context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_MAIN)),
context.relBuilder.nullsLast(context.relBuilder.field(ROW_NUMBER_COLUMN_FOR_SUBSEARCH)));

if (!node.isOverride()) {
// 8. if override = false, drop both _row_number_ columns
context.relBuilder.projectExcept(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,32 @@ public void testAppendcol() {
String expectedLogical =
"LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5],"
+ " COMM=[$6], DEPTNO=[$7])\n"
+ " LogicalJoin(condition=[=($8, $9)], joinType=[full])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER ()])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(_row_number_subsearch_=[ROW_NUMBER() OVER ()])\n"
+ " LogicalFilter(condition=[=($7, 20)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
+ " LogicalSort(sort0=[$8], sort1=[$9], dir0=[ASC], dir1=[ASC])\n"
+ " LogicalJoin(condition=[=($8, $9)], joinType=[full])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER (ORDER BY $0"
+ " NULLS LAST)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(_row_number_subsearch_=[ROW_NUMBER() OVER (ORDER BY $0 NULLS"
+ " LAST)])\n"
+ " LogicalFilter(condition=[=($7, 20)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 14);

String expectedSparkSql =
"SELECT `t`.`EMPNO`, `t`.`ENAME`, `t`.`JOB`, `t`.`MGR`, `t`.`HIREDATE`, `t`.`SAL`,"
+ " `t`.`COMM`, `t`.`DEPTNO`\n"
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`,"
+ " ROW_NUMBER() OVER () `_row_number_main_`\n"
+ " ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST) `_row_number_main_`\n"
+ "FROM `scott`.`EMP`) `t`\n"
+ "FULL JOIN (SELECT ROW_NUMBER() OVER () `_row_number_subsearch_`\n"
+ "FULL JOIN (SELECT ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)"
+ " `_row_number_subsearch_`\n"
+ "FROM `scott`.`EMP`\n"
+ "WHERE `DEPTNO` = 20) `t1` ON `t`.`_row_number_main_` ="
+ " `t1`.`_row_number_subsearch_`";
+ " `t1`.`_row_number_subsearch_`\n"
+ "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t1`.`_row_number_subsearch_` NULLS"
+ " LAST";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

Expand All @@ -54,31 +60,37 @@ public void testAppendcol2() {
String expectedLogical =
"LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5],"
+ " COMM=[$6], DEPTNO=[$7], left_col=[$8], right_col=[$10])\n"
+ " LogicalJoin(condition=[=($9, $11)], joinType=[full])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " LogicalSort(sort0=[$9], sort1=[$11], dir0=[ASC], dir1=[ASC])\n"
+ " LogicalJoin(condition=[=($9, $11)], joinType=[full])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], left_col=[$7], _row_number_main_=[ROW_NUMBER()"
+ " OVER ()])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(right_col=[$8], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n"
+ " LogicalFilter(condition=[=($7, 20)])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " OVER (ORDER BY $0 NULLS LAST)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(right_col=[$8], _row_number_subsearch_=[ROW_NUMBER() OVER"
+ " (ORDER BY $0 NULLS LAST)])\n"
+ " LogicalFilter(condition=[=($7, 20)])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], right_col=[$7])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 14);

String expectedSparkSql =
"SELECT `t`.`EMPNO`, `t`.`ENAME`, `t`.`JOB`, `t`.`MGR`, `t`.`HIREDATE`, `t`.`SAL`,"
+ " `t`.`COMM`, `t`.`DEPTNO`, `t`.`left_col`, `t2`.`right_col`\n"
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`,"
+ " `DEPTNO` `left_col`, ROW_NUMBER() OVER () `_row_number_main_`\n"
+ " `DEPTNO` `left_col`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)"
+ " `_row_number_main_`\n"
+ "FROM `scott`.`EMP`) `t`\n"
+ "FULL JOIN (SELECT `right_col`, ROW_NUMBER() OVER () `_row_number_subsearch_`\n"
+ "FULL JOIN (SELECT `right_col`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)"
+ " `_row_number_subsearch_`\n"
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`,"
+ " `DEPTNO` `right_col`\n"
+ "FROM `scott`.`EMP`) `t0`\n"
+ "WHERE `DEPTNO` = 20) `t2` ON `t`.`_row_number_main_` ="
+ " `t2`.`_row_number_subsearch_`";
+ " `t2`.`_row_number_subsearch_`\n"
+ "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t2`.`_row_number_subsearch_` NULLS"
+ " LAST";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

Expand All @@ -91,14 +103,17 @@ public void testAppendcolOverride() {
+ " JOB=[CASE(=($8, $17), $11, $2)], MGR=[CASE(=($8, $17), $12, $3)],"
+ " HIREDATE=[CASE(=($8, $17), $13, $4)], SAL=[CASE(=($8, $17), $14, $5)],"
+ " COMM=[CASE(=($8, $17), $15, $6)], DEPTNO=[CASE(=($8, $17), $16, $7)])\n"
+ " LogicalJoin(condition=[=($8, $17)], joinType=[full])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER ()])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n"
+ " LogicalFilter(condition=[=($7, 20)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
+ " LogicalSort(sort0=[$8], sort1=[$17], dir0=[ASC], dir1=[ASC])\n"
+ " LogicalJoin(condition=[=($8, $17)], joinType=[full])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_main_=[ROW_NUMBER() OVER (ORDER BY $0"
+ " NULLS LAST)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4],"
+ " SAL=[$5], COMM=[$6], DEPTNO=[$7], _row_number_subsearch_=[ROW_NUMBER() OVER (ORDER"
+ " BY $0 NULLS LAST)])\n"
+ " LogicalFilter(condition=[=($7, 20)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
verifyResultCount(root, 14);

Expand All @@ -116,13 +131,16 @@ public void testAppendcolOverride() {
+ " `t`.`COMM` END `COMM`, CASE WHEN `t`.`_row_number_main_` ="
+ " `t1`.`_row_number_subsearch_` THEN `t1`.`DEPTNO` ELSE `t`.`DEPTNO` END `DEPTNO`\n"
+ "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`,"
+ " ROW_NUMBER() OVER () `_row_number_main_`\n"
+ " ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST) `_row_number_main_`\n"
+ "FROM `scott`.`EMP`) `t`\n"
+ "FULL JOIN (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`,"
+ " `DEPTNO`, ROW_NUMBER() OVER () `_row_number_subsearch_`\n"
+ " `DEPTNO`, ROW_NUMBER() OVER (ORDER BY `EMPNO` NULLS LAST)"
+ " `_row_number_subsearch_`\n"
+ "FROM `scott`.`EMP`\n"
+ "WHERE `DEPTNO` = 20) `t1` ON `t`.`_row_number_main_` ="
+ " `t1`.`_row_number_subsearch_`";
+ " `t1`.`_row_number_subsearch_`\n"
+ "ORDER BY `t`.`_row_number_main_` NULLS LAST, `t1`.`_row_number_subsearch_` NULLS"
+ " LAST";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

Expand All @@ -132,16 +150,17 @@ public void testAppendcolStats() {
RelNode root = getRelNode(ppl);
String expectedLogical =
"LogicalProject(count()=[$0], DEPTNO=[$1], avg(SAL)=[$3])\n"
+ " LogicalJoin(condition=[=($2, $4)], joinType=[full])\n"
+ " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER"
+ " LogicalSort(sort0=[$2], sort1=[$4], dir0=[ASC], dir1=[ASC])\n"
+ " LogicalJoin(condition=[=($2, $4)], joinType=[full])\n"
+ " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER"
+ " ()])\n"
+ " LogicalAggregate(group=[{0}], count()=[COUNT()])\n"
+ " LogicalProject(DEPTNO=[$7])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(avg(SAL)=[$1], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n"
+ " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
+ " LogicalAggregate(group=[{0}], count()=[COUNT()])\n"
+ " LogicalProject(DEPTNO=[$7])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(avg(SAL)=[$1], _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n"
+ " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult =
""
Expand All @@ -159,7 +178,10 @@ public void testAppendcolStats() {
+ "FULL JOIN (SELECT AVG(`SAL`) `avg(SAL)`, ROW_NUMBER() OVER ()"
+ " `_row_number_subsearch_`\n"
+ "FROM `scott`.`EMP`\n"
+ "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` = `t4`.`_row_number_subsearch_`";
+ "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` ="
+ " `t4`.`_row_number_subsearch_`\n"
+ "ORDER BY `t1`.`_row_number_main_` NULLS LAST, `t4`.`_row_number_subsearch_` NULLS"
+ " LAST";
verifyPPLToSparkSQL(root, expectedSparkSql);
}

Expand All @@ -171,17 +193,18 @@ public void testAppendcolStatsOverride() {
RelNode root = getRelNode(ppl);
String expectedLogical =
"LogicalProject(count()=[$0], DEPTNO=[CASE(=($2, $5), $4, $1)], avg(SAL)=[$3])\n"
+ " LogicalJoin(condition=[=($2, $5)], joinType=[full])\n"
+ " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER"
+ " LogicalSort(sort0=[$2], sort1=[$5], dir0=[ASC], dir1=[ASC])\n"
+ " LogicalJoin(condition=[=($2, $5)], joinType=[full])\n"
+ " LogicalProject(count()=[$1], DEPTNO=[$0], _row_number_main_=[ROW_NUMBER() OVER"
+ " ()])\n"
+ " LogicalAggregate(group=[{0}], count()=[COUNT()])\n"
+ " LogicalProject(DEPTNO=[$7])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(avg(SAL)=[$1], DEPTNO=[$0], _row_number_subsearch_=[ROW_NUMBER()"
+ " OVER ()])\n"
+ " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
+ " LogicalAggregate(group=[{0}], count()=[COUNT()])\n"
+ " LogicalProject(DEPTNO=[$7])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalProject(avg(SAL)=[$1], DEPTNO=[$0],"
+ " _row_number_subsearch_=[ROW_NUMBER() OVER ()])\n"
+ " LogicalAggregate(group=[{0}], avg(SAL)=[AVG($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
verifyLogical(root, expectedLogical);
String expectedResult =
""
Expand All @@ -200,7 +223,10 @@ public void testAppendcolStatsOverride() {
+ "FULL JOIN (SELECT AVG(`SAL`) `avg(SAL)`, `DEPTNO`, ROW_NUMBER() OVER ()"
+ " `_row_number_subsearch_`\n"
+ "FROM `scott`.`EMP`\n"
+ "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` = `t4`.`_row_number_subsearch_`";
+ "GROUP BY `DEPTNO`) `t4` ON `t1`.`_row_number_main_` ="
+ " `t4`.`_row_number_subsearch_`\n"
+ "ORDER BY `t1`.`_row_number_main_` NULLS LAST, `t4`.`_row_number_subsearch_` NULLS"
+ " LAST";
verifyPPLToSparkSQL(root, expectedSparkSql);
}
}
Loading