Skip to content

Commit

Permalink
apacheGH-33899: [C++] Add NamedTapRel relation as a Substrait extensi…
Browse files Browse the repository at this point in the history
…on (apache#33909)

See apache#33899. This PR adds `NamedTapRel` and a simple test case with a no-op tap (i.e., just passing-through).
* Closes: apache#33899

Lead-authored-by: Yaron Gvili <rtpsw@hotmail.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
  • Loading branch information
2 people authored and Mike Hancock committed Feb 17, 2023
1 parent 9ce5812 commit f5fd3dd
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 12 deletions.
14 changes: 14 additions & 0 deletions cpp/proto/substrait/extension_rels.proto
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,17 @@ message AsOfJoinRel {
repeated .substrait.Expression by = 2;
}
}

// Named tap relation
//
// A tap is a relation having a single input relation that it passes through, while also
// causing some side-effect, e.g., writing to external storage.
message NamedTapRel {
// The kind of tap
string kind = 1;
// A name used to configure the tap, e.g., a URI defining the destination of writing
string name = 2;
// Column names for the tap's output. If specified there must be one name per field.
// If empty, field names will be automatically generated.
repeated string columns = 3;
}
84 changes: 76 additions & 8 deletions cpp/src/arrow/engine/substrait/options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,51 @@
namespace arrow {
namespace engine {

namespace {

std::vector<compute::Declaration::Input> MakeDeclarationInputs(
const std::vector<DeclarationInfo>& inputs) {
std::vector<compute::Declaration::Input> input_decls(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
input_decls[i] = inputs[i].declaration;
}
return input_decls;
}

} // namespace

class BaseExtensionProvider : public ExtensionProvider {
public:
Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) override {
auto details = dynamic_cast<const DefaultExtensionDetails&>(ext_details);
return MakeRel(inputs, details.rel, ext_set);
return MakeRel(conv_opts, inputs, details.rel, ext_set);
}

virtual Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) = 0;
};

class DefaultExtensionProvider : public BaseExtensionProvider {
public:
Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) override {
if (rel.Is<substrait_ext::AsOfJoinRel>()) {
substrait_ext::AsOfJoinRel as_of_join_rel;
rel.UnpackTo(&as_of_join_rel);
return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set);
}
if (rel.Is<substrait_ext::NamedTapRel>()) {
substrait_ext::NamedTapRel named_tap_rel;
rel.UnpackTo(&named_tap_rel);
return MakeNamedTapRel(conv_opts, inputs, named_tap_rel, ext_set);
}
return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ",
rel.DebugString());
}
Expand Down Expand Up @@ -113,15 +134,38 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys), tolerance};

// declaration
std::vector<compute::Declaration::Input> input_decls(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
input_decls[i] = inputs[i].declaration;
}
auto input_decls = MakeDeclarationInputs(inputs);
return RelationInfo{
{compute::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)),
std::move(schema)},
std::move(field_output_indices)};
}

Result<RelationInfo> MakeNamedTapRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const substrait_ext::NamedTapRel& named_tap_rel,
const ExtensionSet& ext_set) {
if (inputs.size() != 1) {
return Status::Invalid(
"substrait_ext::NamedTapRel requires a single input but got: ", inputs.size());
}

auto schema = inputs[0].output_schema;
int num_fields = schema->num_fields();
if (named_tap_rel.columns_size() != num_fields) {
return Status::Invalid("Got ", named_tap_rel.columns_size(),
" NamedTapRel columns but expected ", num_fields);
}
std::vector<std::string> columns(named_tap_rel.columns().begin(),
named_tap_rel.columns().end());
ARROW_ASSIGN_OR_RAISE(auto renamed_schema, schema->WithNames(columns));
auto input_decls = MakeDeclarationInputs(inputs);
ARROW_ASSIGN_OR_RAISE(
auto decl,
conv_opts.named_tap_provider(named_tap_rel.kind(), input_decls,
named_tap_rel.name(), std::move(renamed_schema)));
return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt};
}
};

namespace {
Expand All @@ -143,5 +187,29 @@ void set_default_extension_provider(const std::shared_ptr<ExtensionProvider>& pr
g_default_extension_provider = provider;
}

namespace {

NamedTapProvider g_default_named_tap_provider =
[](const std::string& tap_kind, std::vector<compute::Declaration::Input> inputs,
const std::string& tap_name,
std::shared_ptr<Schema> tap_schema) -> Result<compute::Declaration> {
return Status::NotImplemented(
"Plan contained a NamedTapRel but no provider configured");
};

std::mutex g_default_named_tap_provider_mutex;

} // namespace

NamedTapProvider default_named_tap_provider() {
std::unique_lock<std::mutex> lock(g_default_named_tap_provider_mutex);
return g_default_named_tap_provider;
}

void set_default_named_tap_provider(NamedTapProvider provider) {
std::unique_lock<std::mutex> lock(g_default_named_tap_provider_mutex);
g_default_named_tap_provider = provider;
}

} // namespace engine
} // namespace arrow
17 changes: 16 additions & 1 deletion cpp/src/arrow/engine/substrait/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <string>
#include <vector>

#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/engine/substrait/type_fwd.h"
#include "arrow/engine/substrait/visibility.h"
Expand Down Expand Up @@ -67,6 +69,10 @@ using NamedTableProvider =
std::function<Result<compute::Declaration>(const std::vector<std::string>&)>;
static NamedTableProvider kDefaultNamedTableProvider;

using NamedTapProvider = std::function<Result<compute::Declaration>(
const std::string&, std::vector<compute::Declaration::Input>, const std::string&,
std::shared_ptr<Schema>)>;

class ARROW_ENGINE_EXPORT ExtensionDetails {
public:
virtual ~ExtensionDetails() = default;
Expand All @@ -75,7 +81,8 @@ class ARROW_ENGINE_EXPORT ExtensionDetails {
class ARROW_ENGINE_EXPORT ExtensionProvider {
public:
virtual ~ExtensionProvider() = default;
virtual Result<RelationInfo> MakeRel(const std::vector<DeclarationInfo>& inputs,
virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) = 0;
};
Expand All @@ -88,6 +95,10 @@ ARROW_ENGINE_EXPORT std::shared_ptr<ExtensionProvider> default_extension_provide
ARROW_ENGINE_EXPORT void set_default_extension_provider(
const std::shared_ptr<ExtensionProvider>& provider);

ARROW_ENGINE_EXPORT NamedTapProvider default_named_tap_provider();

ARROW_ENGINE_EXPORT void set_default_named_tap_provider(NamedTapProvider provider);

/// Options that control the conversion between Substrait and Acero representations of a
/// plan.
struct ARROW_ENGINE_EXPORT ConversionOptions {
Expand All @@ -98,6 +109,10 @@ struct ARROW_ENGINE_EXPORT ConversionOptions {
/// The default behavior will return an invalid status if the plan has any
/// named table relations.
NamedTableProvider named_table_provider = kDefaultNamedTableProvider;
/// \brief A custom strategy to be used for obtaining a tap declaration
///
/// The default provider returns an error
NamedTapProvider named_tap_provider = default_named_tap_provider();
std::shared_ptr<ExtensionProvider> extension_provider = default_extension_provider();
};

Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const substrait::Rel& rel,
case substrait::Rel::RelTypeCase::kExtensionLeaf: {
const auto& ext = rel.extension_leaf();
DefaultExtensionDetails detail{ext.detail()};
return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set);
}

case substrait::Rel::RelTypeCase::kExtensionSingle: {
Expand All @@ -215,7 +215,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const substrait::Rel& rel,
FromProto(ext.input(), ext_set, conv_opts));
inputs.push_back(std::move(input_info));
DefaultExtensionDetails detail{ext.detail()};
return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set);
}

case substrait::Rel::RelTypeCase::kExtensionMulti: {
Expand All @@ -225,7 +225,7 @@ Result<RelationInfo> GetExtensionRelationInfo(const substrait::Rel& rel,
inputs.push_back(std::move(input_info));
}
DefaultExtensionDetails detail{ext.detail()};
return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set);
return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set);
}

default: {
Expand Down
112 changes: 112 additions & 0 deletions cpp/src/arrow/engine/substrait/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/expression.h"
#include "arrow/compute/exec/expression_internal.h"
#include "arrow/compute/exec/map_node.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/compute/exec/util.h"
Expand Down Expand Up @@ -88,6 +89,30 @@ using internal::checked_cast;
using internal::hash_combine;
namespace engine {

Status AddPassFactory(
const std::string& factory_name,
compute::ExecFactoryRegistry* registry = compute::default_exec_factory_registry()) {
using compute::ExecBatch;
using compute::ExecNode;
using compute::ExecNodeOptions;
using compute::ExecPlan;
struct PassNode : public compute::MapNode {
static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const compute::ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "PassNode"));
return plan->EmplaceNode<PassNode>(plan, inputs, inputs[0]->output_schema());
}

PassNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema)
: MapNode(plan, inputs, output_schema) {}

const char* kind_name() const override { return "PassNode"; }
Result<ExecBatch> ProcessBatch(ExecBatch batch) override { return batch; }
};
return registry->AddFactory(factory_name, PassNode::Make);
}

const auto kNullConsumer = std::make_shared<compute::NullSinkNodeConsumer>();

void WriteIpcData(const std::string& path,
Expand Down Expand Up @@ -5355,5 +5380,92 @@ TEST(Substrait, AsOfJoinDefaultEmit) {
CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
}

TEST(Substrait, PlanWithNamedTapExtension) {
// This demos an extension relation
std::string substrait_json = R"({
"extensionUris": [],
"extensions": [],
"relations": [{
"root": {
"input": {
"extension_multi": {
"inputs": [
{
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": ["time", "key", "value"],
"struct": {
"types": [
{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
},
{
"fp64": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": ["T"]
}
}
}
],
"detail": {
"@type": "/arrow.substrait_ext.NamedTapRel",
"kind" : "pass_for_named_tap",
"name" : "does_not_matter",
"columns" : ["pass_time", "pass_key", "pass_value"]
}
}
},
"names": ["t", "k", "v"]
}
}],
"expectedTypeUrls": []
})";

ASSERT_OK(AddPassFactory("pass_for_named_tap"));

std::shared_ptr<Schema> input_schema =
schema({field("time", int32()), field("key", int32()), field("value", float64())});
NamedTableProvider table_provider = AlwaysProvideSameTable(
TableFromJSON(input_schema, {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"}));
ConversionOptions conversion_options;
conversion_options.named_table_provider = std::move(table_provider);
conversion_options.named_tap_provider =
[](const std::string& tap_kind, std::vector<compute::Declaration::Input> inputs,
const std::string& tap_name,
std::shared_ptr<Schema> tap_schema) -> Result<compute::Declaration> {
return compute::Declaration{tap_kind, std::move(inputs), compute::ExecNodeOptions{}};
};

ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));

std::shared_ptr<Schema> output_schema =
schema({field("t", int32()), field("k", int32()), field("v", float64())});
auto expected_table =
TableFromJSON(output_schema, {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"});
CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
}

} // namespace engine
} // namespace arrow
15 changes: 15 additions & 0 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1714,6 +1714,21 @@ bool Schema::HasDistinctFieldNames() const {
return names.size() == fields.size();
}

Result<std::shared_ptr<Schema>> Schema::WithNames(
const std::vector<std::string>& names) const {
if (names.size() != impl_->fields_.size()) {
return Status::Invalid("attempted to rename schema with ", impl_->fields_.size(),
" fields but only ", names.size(), " new names were given");
}
FieldVector new_fields;
new_fields.reserve(names.size());
auto names_itr = names.begin();
for (const auto& field : impl_->fields_) {
new_fields.push_back(field->WithName(*names_itr++));
}
return schema(std::move(new_fields));
}

std::shared_ptr<Schema> Schema::WithMetadata(
const std::shared_ptr<const KeyValueMetadata>& metadata) const {
return std::make_shared<Schema>(impl_->fields_, metadata);
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,12 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable,
Result<std::shared_ptr<Schema>> SetField(int i,
const std::shared_ptr<Field>& field) const;

/// \brief Replace field names with new names
///
/// \param[in] names new names
/// \return new Schema
Result<std::shared_ptr<Schema>> WithNames(const std::vector<std::string>& names) const;

/// \brief Replace key-value metadata with new metadata
///
/// \param[in] metadata new KeyValueMetadata
Expand Down

0 comments on commit f5fd3dd

Please sign in to comment.