Skip to content

Commit

Permalink
[GJ-112] Enable filter pushdown for TPCH Q1&Q6 on arrow backend (subs…
Browse files Browse the repository at this point in the history
…trait-io#113)

* push down filter to scan options

* Update cpp/gazelle-cpp/compute/substrait_arrow.cc

Co-authored-by: david <david.caiq@gmail.com>

* lint

* update function namespace

Co-authored-by: david <david.caiq@gmail.com>
  • Loading branch information
marin-ma and QiangCai committed Apr 14, 2022
1 parent 5d62f26 commit a0c1ec4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
62 changes: 62 additions & 0 deletions cpp/gazelle-cpp/compute/substrait_arrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <arrow/compute/exec/options.h>
#include <arrow/compute/registry.h>
#include <arrow/dataset/scanner.h>

#include "jni/exec_backend.h"

Expand Down Expand Up @@ -72,6 +73,8 @@ ArrowExecBackend::GetResultIterator(
ReplaceSourceDecls(std::move(source_decls));
}

PushDownFilter();

// Make plan
GLUTEN_ASSIGN_OR_THROW(exec_plan_, arrow::compute::ExecPlan::Make());
GLUTEN_ASSIGN_OR_THROW(auto node, decl_->AddToPlan(exec_plan_.get()));
Expand Down Expand Up @@ -100,6 +103,65 @@ ArrowExecBackend::GetResultIterator(
shared_from_this());
}

void ArrowExecBackend::PushDownFilter() {
std::vector<arrow::compute::Declaration*> visited;

visited.push_back(decl_.get());

while (!visited.empty()) {
auto top = visited.back();
visited.pop_back();
for (auto& input : top->inputs) {
auto& input_decl = arrow::util::get<arrow::compute::Declaration>(input);
if (input_decl.factory_name == "filter" && input_decl.inputs.size() == 1) {
auto scan_decl =
arrow::util::get<arrow::compute::Declaration>(input_decl.inputs[0]);
if (scan_decl.factory_name == "scan") {
auto expression =
arrow::internal::checked_pointer_cast<arrow::compute::FilterNodeOptions>(
input_decl.options)
->filter_expression;
auto scan_options =
arrow::internal::checked_pointer_cast<arrow::dataset::ScanNodeOptions>(
scan_decl.options);
const auto& schema = scan_options->dataset->schema();
FieldPathToName(&expression, schema);
scan_options->scan_options->filter = std::move(expression);
continue;
}
}
visited.push_back(&input_decl);
}
}
}

// Rewrites, in place, every field reference inside `expression` from an
// index-based field path to a name-based reference resolved against `schema`.
//
// The expression tree is walked iteratively: call expressions contribute their
// arguments to the work stack, field references are replaced, and every other
// kind of expression is left untouched.
//
// Only the first index of a field path is used to look up the column name.
// NOTE(review): a path with more than one index would therefore resolve to its
// top-level column -- confirm nested references cannot reach this point.
//
// Throws gluten::JniPendingException when a field reference is not a field
// path, or when its first index is missing or outside the schema.
void ArrowExecBackend::FieldPathToName(arrow::compute::Expression* expression,
                                       const std::shared_ptr<arrow::Schema>& schema) {
  // Expressions that still have to be inspected.
  std::vector<arrow::compute::Expression*> pending;
  pending.push_back(expression);

  while (!pending.empty()) {
    auto* expr = pending.back();
    pending.pop_back();

    if (expr->call()) {
      // call() only exposes a const view; the arguments are edited in place, so
      // constness has to be cast away here.
      auto* call = const_cast<arrow::compute::Expression::Call*>(expr->call());
      for (auto& argument : call->arguments) {
        pending.push_back(&argument);
      }
    } else if (const auto* ref = expr->field_ref()) {
      const auto* field_path = ref->field_path();
      if (field_path == nullptr) {
        throw gluten::JniPendingException("Field Ref is not field path: " +
                                          ref->ToString());
      }

      const auto& indices = field_path->indices();
      if (indices.empty() || indices[0] < 0 || indices[0] >= schema->num_fields()) {
        // Indexing the path or the schema here would be undefined behaviour.
        throw gluten::JniPendingException("Field path index is out of range: " +
                                          ref->ToString());
      }

      // Copy the index out first: assigning to *expr below releases the object
      // that `ref`, `field_path` and `indices` point into.
      const int column_index = indices[0];
      *expr = arrow::compute::field_ref(schema->field(column_index)->name());
    }
  }
}

void ArrowExecBackend::ReplaceSourceDecls(
std::vector<arrow::compute::Declaration> source_decls) {
std::vector<arrow::compute::Declaration*> visited;
Expand Down
3 changes: 3 additions & 0 deletions cpp/gazelle-cpp/compute/substrait_arrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class ArrowExecBackend : public gluten::ExecBackendBase {
std::shared_ptr<arrow::compute::ExecPlan> exec_plan_;

void ReplaceSourceDecls(std::vector<arrow::compute::Declaration> source_decls);
void PushDownFilter();
static void FieldPathToName(arrow::compute::Expression* expression,
const std::shared_ptr<arrow::Schema>& schema);
};

void Initialize();
Expand Down

0 comments on commit a0c1ec4

Please sign in to comment.