diff --git a/cpp/proto/substrait/extension_rels.proto b/cpp/proto/substrait/extension_rels.proto index 0392405f804bc..78c11b7d7e22d 100644 --- a/cpp/proto/substrait/extension_rels.proto +++ b/cpp/proto/substrait/extension_rels.proto @@ -44,3 +44,17 @@ message AsOfJoinRel { repeated .substrait.Expression by = 2; } } + +// Named tap relation +// +// A tap is a relation having a single input relation that it passes through, while also +// causing some side-effect, e.g., writing to external storage. +message NamedTapRel { + // The kind of tap + string kind = 1; + // A name used to configure the tap, e.g., a URI defining the destination of writing + string name = 2; + // Column names for the tap's output. If specified there must be one name per field. + // If empty, field names will be automatically generated. + repeated string columns = 3; +} diff --git a/cpp/src/arrow/engine/substrait/options.cc b/cpp/src/arrow/engine/substrait/options.cc index be23ce1e64cbb..b4b10a021d096 100644 --- a/cpp/src/arrow/engine/substrait/options.cc +++ b/cpp/src/arrow/engine/substrait/options.cc @@ -30,23 +30,39 @@ namespace arrow { namespace engine { +namespace { + +std::vector MakeDeclarationInputs( + const std::vector& inputs) { + std::vector input_decls(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + input_decls[i] = inputs[i].declaration; + } + return input_decls; +} + +} // namespace + class BaseExtensionProvider : public ExtensionProvider { public: - Result MakeRel(const std::vector& inputs, + Result MakeRel(const ConversionOptions& conv_opts, + const std::vector& inputs, const ExtensionDetails& ext_details, const ExtensionSet& ext_set) override { auto details = dynamic_cast(ext_details); - return MakeRel(inputs, details.rel, ext_set); + return MakeRel(conv_opts, inputs, details.rel, ext_set); } - virtual Result MakeRel(const std::vector& inputs, + virtual Result MakeRel(const ConversionOptions& conv_opts, + const std::vector& inputs, const google::protobuf::Any& rel, const ExtensionSet& ext_set) = 0; }; class DefaultExtensionProvider : public BaseExtensionProvider { public: - Result MakeRel(const std::vector& inputs, + Result MakeRel(const ConversionOptions& conv_opts, + const std::vector& inputs, const google::protobuf::Any& rel, const ExtensionSet& ext_set) override { if (rel.Is()) { @@ -54,6 +70,11 @@ class DefaultExtensionProvider : public BaseExtensionProvider { rel.UnpackTo(&as_of_join_rel); return MakeAsOfJoinRel(inputs, as_of_join_rel, ext_set); } + if (rel.Is()) { + substrait_ext::NamedTapRel named_tap_rel; + rel.UnpackTo(&named_tap_rel); + return MakeNamedTapRel(conv_opts, inputs, named_tap_rel, ext_set); + } return Status::NotImplemented("Unrecognized extension in Susbstrait plan: ", rel.DebugString()); } @@ -113,15 +134,38 @@ class DefaultExtensionProvider : public BaseExtensionProvider { compute::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys), tolerance}; // declaration - std::vector input_decls(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - input_decls[i] = inputs[i].declaration; - } + auto input_decls = MakeDeclarationInputs(inputs); return RelationInfo{ {compute::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)), std::move(schema)}, std::move(field_output_indices)}; } + + Result MakeNamedTapRel(const ConversionOptions& conv_opts, + const std::vector& inputs, + const substrait_ext::NamedTapRel& named_tap_rel, + const ExtensionSet& ext_set) { + if (inputs.size() != 1) { + return Status::Invalid( + "substrait_ext::NamedTapRel requires a single input but got: ", inputs.size()); + } + + auto schema = inputs[0].output_schema; + int num_fields = schema->num_fields(); + if (named_tap_rel.columns_size() != num_fields) { + return Status::Invalid("Got ", named_tap_rel.columns_size(), + " NamedTapRel columns but expected ", num_fields); + } + std::vector columns(named_tap_rel.columns().begin(), + named_tap_rel.columns().end()); + ARROW_ASSIGN_OR_RAISE(auto renamed_schema, schema->WithNames(columns)); + auto input_decls = MakeDeclarationInputs(inputs); + ARROW_ASSIGN_OR_RAISE( + auto decl, + conv_opts.named_tap_provider(named_tap_rel.kind(), input_decls, + named_tap_rel.name(), std::move(renamed_schema))); + return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt}; + } }; namespace { @@ -143,5 +187,29 @@ void set_default_extension_provider(const std::shared_ptr& pr g_default_extension_provider = provider; } +namespace { + +NamedTapProvider g_default_named_tap_provider = + [](const std::string& tap_kind, std::vector inputs, + const std::string& tap_name, + std::shared_ptr tap_schema) -> Result { + return Status::NotImplemented( + "Plan contained a NamedTapRel but no provider configured"); +}; + +std::mutex g_default_named_tap_provider_mutex; + +} // namespace + +NamedTapProvider default_named_tap_provider() { + std::unique_lock lock(g_default_named_tap_provider_mutex); + return g_default_named_tap_provider; +} + +void set_default_named_tap_provider(NamedTapProvider provider) { + std::unique_lock lock(g_default_named_tap_provider_mutex); + g_default_named_tap_provider = provider; +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index 35a4f70aa9107..3b4a6963ac087 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -23,6 +23,8 @@ #include #include +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/options.h" #include "arrow/compute/type_fwd.h" #include "arrow/engine/substrait/type_fwd.h" #include "arrow/engine/substrait/visibility.h" @@ -67,6 +69,10 @@ using NamedTableProvider = std::function(const std::vector&)>; static NamedTableProvider kDefaultNamedTableProvider; +using NamedTapProvider = std::function( + const std::string&, std::vector, const std::string&, + std::shared_ptr)>; + class ARROW_ENGINE_EXPORT ExtensionDetails { public: virtual ~ExtensionDetails() = default; @@ -75,7 +81,8 @@ class ARROW_ENGINE_EXPORT ExtensionDetails { class ARROW_ENGINE_EXPORT ExtensionProvider { public: virtual ~ExtensionProvider() = default; - virtual Result MakeRel(const std::vector& inputs, + virtual Result MakeRel(const ConversionOptions& conv_opts, + const std::vector& inputs, const ExtensionDetails& ext_details, const ExtensionSet& ext_set) = 0; }; @@ -88,6 +95,10 @@ ARROW_ENGINE_EXPORT std::shared_ptr default_extension_provide ARROW_ENGINE_EXPORT void set_default_extension_provider( const std::shared_ptr& provider); +ARROW_ENGINE_EXPORT NamedTapProvider default_named_tap_provider(); + +ARROW_ENGINE_EXPORT void set_default_named_tap_provider(NamedTapProvider provider); + /// Options that control the conversion between Substrait and Acero representations of a /// plan. struct ARROW_ENGINE_EXPORT ConversionOptions { @@ -98,6 +109,10 @@ struct ARROW_ENGINE_EXPORT ConversionOptions { /// The default behavior will return an invalid status if the plan has any /// named table relations. NamedTableProvider named_table_provider = kDefaultNamedTableProvider; + /// \brief A custom strategy to be used for obtaining a tap declaration + /// + /// The default provider returns an error + NamedTapProvider named_tap_provider = default_named_tap_provider(); std::shared_ptr extension_provider = default_extension_provider(); }; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 4fb7bb2a78633..19a38cd40e2f0 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -206,7 +206,7 @@ Result GetExtensionRelationInfo(const substrait::Rel& rel, case substrait::Rel::RelTypeCase::kExtensionLeaf: { const auto& ext = rel.extension_leaf(); DefaultExtensionDetails detail{ext.detail()}; - return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set); + return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set); } case substrait::Rel::RelTypeCase::kExtensionSingle: { @@ -215,7 +215,7 @@ Result GetExtensionRelationInfo(const substrait::Rel& rel, FromProto(ext.input(), ext_set, conv_opts)); inputs.push_back(std::move(input_info)); DefaultExtensionDetails detail{ext.detail()}; - return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set); + return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set); } case substrait::Rel::RelTypeCase::kExtensionMulti: { @@ -225,7 +225,7 @@ Result GetExtensionRelationInfo(const substrait::Rel& rel, inputs.push_back(std::move(input_info)); } DefaultExtensionDetails detail{ext.detail()}; - return conv_opts.extension_provider->MakeRel(inputs, detail, ext_set); + return conv_opts.extension_provider->MakeRel(conv_opts, inputs, detail, ext_set); } default: { diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index d99d6fa0c4f32..97a2aea3985bb 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -36,6 +36,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/expression_internal.h" +#include "arrow/compute/exec/map_node.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/test_util.h" #include "arrow/compute/exec/util.h" @@ -88,6 +89,30 @@ using internal::checked_cast; using internal::hash_combine; namespace engine { +Status AddPassFactory( + const std::string& factory_name, + compute::ExecFactoryRegistry* registry = compute::default_exec_factory_registry()) { + using compute::ExecBatch; + using compute::ExecNode; + using compute::ExecNodeOptions; + using compute::ExecPlan; + struct PassNode : public compute::MapNode { + static Result Make(ExecPlan* plan, std::vector inputs, + const compute::ExecNodeOptions& options) { + RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "PassNode")); + return plan->EmplaceNode(plan, inputs, inputs[0]->output_schema()); + } + + PassNode(ExecPlan* plan, std::vector inputs, + std::shared_ptr output_schema) + : MapNode(plan, inputs, output_schema) {} + + const char* kind_name() const override { return "PassNode"; } + Result ProcessBatch(ExecBatch batch) override { return batch; } + }; + return registry->AddFactory(factory_name, PassNode::Make); +} + const auto kNullConsumer = std::make_shared(); void WriteIpcData(const std::string& path, @@ -5355,5 +5380,92 @@ TEST(Substrait, AsOfJoinDefaultEmit) { CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); } +TEST(Substrait, PlanWithNamedTapExtension) { + // This demos an extension relation + std::string substrait_json = R"({ + "extensionUris": [], + "extensions": [], + "relations": [{ + "root": { + "input": { + "extension_multi": { + "inputs": [ + { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["time", "key", "value"], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "fp64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["T"] + } + } + } + ], + "detail": { + "@type": "/arrow.substrait_ext.NamedTapRel", + "kind" : "pass_for_named_tap", + "name" : "does_not_matter", + "columns" : ["pass_time", "pass_key", "pass_value"] + } + } + }, + "names": ["t", "k", "v"] + } + }], + "expectedTypeUrls": [] + })"; + + ASSERT_OK(AddPassFactory("pass_for_named_tap")); + + std::shared_ptr input_schema = + schema({field("time", int32()), field("key", int32()), field("value", float64())}); + NamedTableProvider table_provider = AlwaysProvideSameTable( + TableFromJSON(input_schema, {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"})); + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + conversion_options.named_tap_provider = + [](const std::string& tap_kind, std::vector inputs, + const std::string& tap_name, + std::shared_ptr tap_schema) -> Result { + return compute::Declaration{tap_kind, std::move(inputs), compute::ExecNodeOptions{}}; + }; + + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); + + std::shared_ptr output_schema = + schema({field("t", int32()), field("k", int32()), field("v", float64())}); + auto expected_table = + TableFromJSON(output_schema, {"[[2, 1, 1.1], [4, 1, 2.1], [6, 2, 3.1]]"}); + CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 825a68e68cfbe..0e3732db6ed3f 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1714,6 +1714,21 @@ bool Schema::HasDistinctFieldNames() const { return names.size() == fields.size(); } +Result> Schema::WithNames( + const std::vector& names) const { + if (names.size() != impl_->fields_.size()) { + return Status::Invalid("attempted to rename schema with ", impl_->fields_.size(), + " fields but only ", names.size(), " new names were given"); + } + FieldVector new_fields; + new_fields.reserve(names.size()); + auto names_itr = names.begin(); + for (const auto& field : impl_->fields_) { + new_fields.push_back(field->WithName(*names_itr++)); + } + return schema(std::move(new_fields)); +} + std::shared_ptr Schema::WithMetadata( const std::shared_ptr& metadata) const { return std::make_shared(impl_->fields_, metadata); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index cf58218a7e741..4ea47962314a0 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1968,6 +1968,12 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable, Result> SetField(int i, const std::shared_ptr& field) const; + /// \brief Replace field names with new names + /// + /// \param[in] names new names + /// \return new Schema + Result> WithNames(const std::vector& names) const; + /// \brief Replace key-value metadata with new metadata /// /// \param[in] metadata new KeyValueMetadata