Skip to content

Commit

Permalink
progress 3/17
Browse files Browse the repository at this point in the history
  • Loading branch information
julianhyde committed Mar 18, 2023
1 parent 5c96775 commit b0e3107
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 6 deletions.
6 changes: 0 additions & 6 deletions core/src/main/java/org/apache/calcite/plan/RelOptUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rel.rules.MeasureRules;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.stream.StreamRules;
import org.apache.calcite.rel.type.RelDataType;
Expand Down Expand Up @@ -2040,11 +2039,6 @@ public static void registerDefaultRules(RelOptPlanner planner,
planner.addRule(rule);
}
}
if (true) {
for (RelOptRule rule : MeasureRules.rules()) {
planner.addRule(rule);
}
}
// Registers this rule for default ENUMERABLE convention
// because:
// 1. ScannableTable can bind data directly;
Expand Down
69 changes: 69 additions & 0 deletions core/src/main/java/org/apache/calcite/rel/RelNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,21 @@
*/
package org.apache.calcite.rel;

import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.runtime.Utilities;
import org.apache.calcite.util.Util;

import com.google.common.collect.Ordering;

import java.util.Comparator;
import java.util.function.BiConsumer;
import java.util.function.Predicate;

/**
* Utilities concerning relational expressions.
Expand Down Expand Up @@ -51,6 +61,65 @@ public static int compareRels(RelNode[] rels0, RelNode[] rels1) {
return 0;
}

/** Returns whether a tree of {@link RelNode}s contains a match for a
* {@link RexNode} finder. */
public static boolean contains(RelNode rel,
Predicate<AggregateCall> aggPredicate, RexUtil.RexFinder finder) {
try {
findRex(rel, finder, aggPredicate, (relNode, rexNode) -> {
throw Util.FoundOne.NULL;
});
return false;
} catch (Util.FoundOne e) {
return true;
}
}

/** Searches for expressions in a tree of {@link RelNode}s. */
// TODO: a new method RelNode.accept(RexVisitor, BiConsumer), with similar
// overrides to RelNode.accept(RexShuttle), would be better.
public static void findRex(RelNode rel, RexUtil.RexFinder finder,
Predicate<AggregateCall> aggPredicate,
BiConsumer<RelNode, RexNode> consumer) {
if (rel instanceof Filter) {
Filter filter = (Filter) rel;
try {
filter.getCondition().accept(finder);
} catch (Util.FoundOne e) {
consumer.accept(filter, (RexNode) e.getNode());
}
}
if (rel instanceof Project) {
Project project = (Project) rel;
for (RexNode node : project.getProjects()) {
try {
node.accept(finder);
} catch (Util.FoundOne e) {
consumer.accept(project, (RexNode) e.getNode());
}
}
}
if (rel instanceof Join) {
Join join = (Join) rel;
try {
join.getCondition().accept(finder);
} catch (Util.FoundOne e) {
consumer.accept(join, (RexNode) e.getNode());
}
}
if (rel instanceof Aggregate) {
Aggregate aggregate = (Aggregate) rel;
for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
if (aggPredicate.test(aggregateCall)) {
consumer.accept(aggregate, null);
}
}
}
for (RelNode input : rel.getInputs()) {
findRex(input, finder, aggPredicate, consumer);
}
}

/** Arbitrary stable comparator for {@link RelNode}s. */
private static class RelNodeComparator implements Comparator<RelNode> {
@Override public int compare(RelNode o1, RelNode o2) {
Expand Down
54 changes: 54 additions & 0 deletions core/src/main/java/org/apache/calcite/tools/Programs.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,19 @@
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelNodes;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.ChainedRelMetadataProvider;
import org.apache.calcite.rel.metadata.DefaultRelMetadataProvider;
import org.apache.calcite.rel.metadata.RelMetadataProvider;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rel.rules.JoinPushThroughJoinRule;
import org.apache.calcite.rel.rules.MeasureRules;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql2rel.RelDecorrelator;
import org.apache.calcite.sql2rel.RelFieldTrimmer;
import org.apache.calcite.sql2rel.SqlToRelConverter;
Expand All @@ -51,6 +57,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;

import static org.apache.calcite.linq4j.Nullness.castNonNull;

Expand Down Expand Up @@ -141,6 +148,12 @@ public static Program sequence(Program... programs) {
return new SequenceProgram(ImmutableList.copyOf(programs));
}

/** Creates a program that executes if a predicate is true. */
public static Program conditional(Predicate<RelNode> predicate,
Program program) {
return new ConditionalProgram(predicate, program);
}

/** Creates a program that executes a list of rules in a HEP planner. */
public static Program hep(Iterable<? extends RelOptRule> rules,
boolean noDag, RelMetadataProvider metadataProvider) {
Expand Down Expand Up @@ -243,6 +256,19 @@ public static Program subQuery(RelMetadataProvider metadataProvider) {
return of(builder.build(), true, metadataProvider);
}

public static Program measure(RelMetadataProvider metadataProvider) {
return conditional(Programs::containsAggM2v,
sequence(hep(MeasureRules.rules(), true, metadataProvider),
subQuery(metadataProvider),
new DecorrelateProgram()));
}

private static boolean containsAggM2v(RelNode rel) {
return RelNodes.contains(rel,
aggCall -> aggCall.getAggregation().kind == SqlKind.AGG_M2V,
RexUtil.find(SqlKind.AGG_M2V));
}

@Deprecated
public static Program getProgram() {
return (planner, rel, requiredOutputTraits, materializations, lattices) ->
Expand Down Expand Up @@ -281,6 +307,7 @@ public static Program standard(RelMetadataProvider metadataProvider) {

return sequence(subQuery(metadataProvider),
new DecorrelateProgram(),
measure(metadataProvider),
new TrimFieldsProgram(),
program1,

Expand Down Expand Up @@ -333,9 +360,36 @@ private static class SequenceProgram implements Program {
RelTraitSet requiredOutputTraits,
List<RelOptMaterialization> materializations,
List<RelOptLattice> lattices) {
int i = 0;
for (Program program : programs) {
rel = program.run(
planner, rel, requiredOutputTraits, materializations, lattices);

System.out.println(
RelOptUtil.dumpPlan("pass #" + i++, rel, SqlExplainFormat.TEXT,
SqlExplainLevel.DIGEST_ATTRIBUTES));
}
return rel;
}
}

/** Program that runs a sub-program only if a condition is true. */
private static class ConditionalProgram implements Program {
private final Predicate<RelNode> predicate;
private final Program program;

ConditionalProgram(Predicate<RelNode> predicate, Program program) {
this.predicate = predicate;
this.program = program;
}

@Override public RelNode run(RelOptPlanner planner, RelNode rel,
RelTraitSet requiredOutputTraits,
List<RelOptMaterialization> materializations,
List<RelOptLattice> lattices) {
if (predicate.test(rel)) {
return program.run(planner, rel, requiredOutputTraits, materializations,
lattices);
}
return rel;
}
Expand Down

0 comments on commit b0e3107

Please sign in to comment.