Skip to content

Commit

Permalink
ARROW-16549: [C++] Simplify AggregateNodeOptions aggregates/targets (a…
Browse files Browse the repository at this point in the history
…pache#13150)

This PR simplifies the existing `AggregateNodeOptions` usage. The work is still in progress: the refactor and its existing usages still need to be evaluated.

Todos

- [x] Test 
- [ ] Update documentation
- [ ] Update function docs
- [x] Evaluate CI failures (only tested on Mac M1 with C++/Python; need to check whether the change breaks other language bindings)


Authored-by: Vibhatha Abeykoon <vibhatha@gmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
  • Loading branch information
vibhatha authored and djnavarro committed Jun 28, 2022
1 parent 67536ca commit 0010f97
Show file tree
Hide file tree
Showing 22 changed files with 451 additions and 505 deletions.
22 changes: 8 additions & 14 deletions c_glib/arrow-glib/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,9 +1254,7 @@ garrow_aggregate_node_options_new(GList *aggregations,
gsize n_keys,
GError **error)
{
std::vector<arrow::compute::internal::Aggregate> arrow_aggregates;
std::vector<arrow::FieldRef> arrow_targets;
std::vector<std::string> arrow_names;
std::vector<arrow::compute::Aggregate> arrow_aggregates;
std::vector<arrow::FieldRef> arrow_keys;
for (auto node = aggregations; node; node = node->next) {
auto aggregation_priv = GARROW_AGGREGATION_GET_PRIVATE(node->data);
Expand All @@ -1265,21 +1263,19 @@ garrow_aggregate_node_options_new(GList *aggregations,
function_options =
garrow_function_options_get_raw(aggregation_priv->options);
};
if (function_options) {
arrow_aggregates.push_back({
aggregation_priv->function,
function_options->Copy(),
});
} else {
arrow_aggregates.push_back({aggregation_priv->function, nullptr});
};
std::vector<arrow::FieldRef> arrow_targets;
if (!garrow_field_refs_add(arrow_targets,
aggregation_priv->input,
error,
"[aggregate-node-options][new][input]")) {
return NULL;
}
arrow_names.emplace_back(aggregation_priv->output);
arrow_aggregates.push_back({
aggregation_priv->function,
function_options ? function_options->Copy() : nullptr,
arrow_targets[0],
aggregation_priv->output,
});
}
for (gsize i = 0; i < n_keys; ++i) {
if (!garrow_field_refs_add(arrow_keys,
Expand All @@ -1291,8 +1287,6 @@ garrow_aggregate_node_options_new(GList *aggregations,
}
auto arrow_options =
new arrow::compute::AggregateNodeOptions(std::move(arrow_aggregates),
std::move(arrow_targets),
std::move(arrow_names),
std::move(arrow_keys));
auto options = g_object_new(GARROW_TYPE_AGGREGATE_NODE_OPTIONS,
"options", arrow_options,
Expand Down
2 changes: 1 addition & 1 deletion c_glib/arrow-glib/compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ garrow_quantile_options_get_qs(GArrowQuantileOptions *options,
GARROW_AVAILABLE_IN_9_0
void
garrow_quantile_options_set_q(GArrowQuantileOptions *options,
gdouble quantile);
gdouble q);
GARROW_AVAILABLE_IN_9_0
void
garrow_quantile_options_set_qs(GArrowQuantileOptions *options,
Expand Down
9 changes: 3 additions & 6 deletions cpp/examples/arrow/execution_plan_documentation_examples.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,8 @@ arrow::Status SourceScalarAggregateSinkExample(cp::ExecContext& exec_context) {

ARROW_ASSIGN_OR_RAISE(cp::ExecNode * source,
cp::MakeExecNode("source", plan.get(), {}, source_node_options));
auto aggregate_options = cp::AggregateNodeOptions{/*aggregates=*/{{"sum", nullptr}},
/*targets=*/{"a"},
/*names=*/{"sum(a)"}};
auto aggregate_options =
cp::AggregateNodeOptions{/*aggregates=*/{{"sum", nullptr, "a", "sum(a)"}}};
ARROW_ASSIGN_OR_RAISE(
cp::ExecNode * aggregate,
cp::MakeExecNode("aggregate", plan.get(), {source}, std::move(aggregate_options)));
Expand Down Expand Up @@ -541,9 +540,7 @@ arrow::Status SourceGroupAggregateSinkExample(cp::ExecContext& exec_context) {
cp::MakeExecNode("source", plan.get(), {}, source_node_options));
auto options = std::make_shared<cp::CountOptions>(cp::CountOptions::ONLY_VALID);
auto aggregate_options =
cp::AggregateNodeOptions{/*aggregates=*/{{"hash_count", options}},
/*targets=*/{"a"},
/*names=*/{"count(a)"},
cp::AggregateNodeOptions{/*aggregates=*/{{"hash_count", options, "a", "count(a)"}},
/*keys=*/{"b"}};
ARROW_ASSIGN_OR_RAISE(
cp::ExecNode * aggregate,
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/arrow/compute/api_aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,17 +393,20 @@ ARROW_EXPORT
Result<Datum> Index(const Datum& value, const IndexOptions& options,
ExecContext* ctx = NULLPTR);

namespace internal {

/// \brief Configure a grouped aggregation
///
/// Bundles everything needed to describe one aggregation: the function to
/// call, its options, the input field it reads, and the name of the output
/// field it produces.
struct ARROW_EXPORT Aggregate {
/// the name of the aggregation function
std::string function;

/// options for the aggregation function
std::shared_ptr<FunctionOptions> options;

/// the field to which this aggregation will be applied
FieldRef target;

/// the output field name for this aggregation's result
std::string name;
};

} // namespace internal
} // namespace compute
} // namespace arrow
5 changes: 4 additions & 1 deletion cpp/src/arrow/compute/exec/aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/compute/row/grouper.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/task_group.h"

namespace arrow {
Expand Down Expand Up @@ -55,7 +56,9 @@ Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
std::vector<std::unique_ptr<KernelState>> states(kernels.size());

for (size_t i = 0; i < aggregates.size(); ++i) {
const FunctionOptions* options = aggregates[i].options.get();
const FunctionOptions* options =
arrow::internal::checked_cast<const FunctionOptions*>(
aggregates[i].options.get());

if (options == nullptr) {
// use known default options for the named function if possible
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/arrow/compute/exec/aggregate.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,15 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
ExecContext* ctx = default_exec_context());

Result<std::vector<const HashAggregateKernel*>> GetKernels(
ExecContext* ctx, const std::vector<internal::Aggregate>& aggregates,
ExecContext* ctx, const std::vector<Aggregate>& aggregates,
const std::vector<ValueDescr>& in_descrs);

Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
const std::vector<const HashAggregateKernel*>& kernels, ExecContext* ctx,
const std::vector<internal::Aggregate>& aggregates,
const std::vector<ValueDescr>& in_descrs);
const std::vector<Aggregate>& aggregates, const std::vector<ValueDescr>& in_descrs);

Result<FieldVector> ResolveKernels(
const std::vector<internal::Aggregate>& aggregates,
const std::vector<Aggregate>& aggregates,
const std::vector<const HashAggregateKernel*>& kernels,
const std::vector<std::unique_ptr<KernelState>>& states, ExecContext* ctx,
const std::vector<ValueDescr>& descrs);
Expand Down
27 changes: 12 additions & 15 deletions cpp/src/arrow/compute/exec/aggregate_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace compute {
namespace {

void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
const std::vector<internal::Aggregate>& aggs,
const std::vector<Aggregate>& aggs,
const std::vector<int>& target_field_ids, int indent = 0) {
*ss << "aggregates=[" << std::endl;
for (size_t i = 0; i < aggs.size(); i++) {
Expand All @@ -64,8 +64,7 @@ class ScalarAggregateNode : public ExecNode {
public:
ScalarAggregateNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
std::vector<int> target_field_ids,
std::vector<internal::Aggregate> aggs,
std::vector<int> target_field_ids, std::vector<Aggregate> aggs,
std::vector<const ScalarAggregateKernel*> kernels,
std::vector<std::vector<std::unique_ptr<KernelState>>> states)
: ExecNode(plan, std::move(inputs), {"target"},
Expand All @@ -89,12 +88,12 @@ class ScalarAggregateNode : public ExecNode {
std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
FieldVector fields(kernels.size());
const auto& field_names = aggregate_options.names;
std::vector<int> target_field_ids(kernels.size());

for (size_t i = 0; i < kernels.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(auto match,
FieldRef(aggregate_options.targets[i]).FindOne(input_schema));
ARROW_ASSIGN_OR_RAISE(
auto match,
FieldRef(aggregate_options.aggregates[i].target).FindOne(input_schema));
target_field_ids[i] = match[0];

ARROW_ASSIGN_OR_RAISE(
Expand Down Expand Up @@ -129,7 +128,7 @@ class ScalarAggregateNode : public ExecNode {
ARROW_ASSIGN_OR_RAISE(
auto descr, kernels[i]->signature->out_type().Resolve(&kernel_ctx, {in_type}));

fields[i] = field(field_names[i], std::move(descr.type));
fields[i] = field(aggregate_options.aggregates[i].name, std::move(descr.type));
}

return plan->EmplaceNode<ScalarAggregateNode>(
Expand Down Expand Up @@ -263,7 +262,7 @@ class ScalarAggregateNode : public ExecNode {
}

const std::vector<int> target_field_ids_;
const std::vector<internal::Aggregate> aggs_;
const std::vector<Aggregate> aggs_;
const std::vector<const ScalarAggregateKernel*> kernels_;

std::vector<std::vector<std::unique_ptr<KernelState>>> states_;
Expand All @@ -276,7 +275,7 @@ class GroupByNode : public ExecNode {
public:
GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema, ExecContext* ctx,
std::vector<int> key_field_ids, std::vector<int> agg_src_field_ids,
std::vector<internal::Aggregate> aggs,
std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema),
/*num_outputs=*/1),
Expand All @@ -295,7 +294,6 @@ class GroupByNode : public ExecNode {
const auto& keys = aggregate_options.keys;
// Copy (need to modify options pointer below)
auto aggs = aggregate_options.aggregates;
const auto& field_names = aggregate_options.names;

// Get input schema
auto input_schema = input->output_schema();
Expand All @@ -310,13 +308,11 @@ class GroupByNode : public ExecNode {
// Find input field indices for aggregates
std::vector<int> agg_src_field_ids(aggs.size());
for (size_t i = 0; i < aggs.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(auto match,
aggregate_options.targets[i].FindOne(*input_schema));
ARROW_ASSIGN_OR_RAISE(auto match, aggs[i].target.FindOne(*input_schema));
agg_src_field_ids[i] = match[0];
}

// Build vector of aggregate source field data types
DCHECK_EQ(aggregate_options.targets.size(), aggs.size());
std::vector<ValueDescr> agg_src_descrs(aggs.size());
for (size_t i = 0; i < aggs.size(); ++i) {
auto agg_src_field_id = agg_src_field_ids[i];
Expand All @@ -342,7 +338,8 @@ class GroupByNode : public ExecNode {

// Aggregate fields come before key fields to match the behavior of GroupBy function
for (size_t i = 0; i < aggs.size(); ++i) {
output_fields[i] = agg_result_fields[i]->WithName(field_names[i]);
output_fields[i] =
agg_result_fields[i]->WithName(aggregate_options.aggregates[i].name);
}
size_t base = aggs.size();
for (size_t i = 0; i < keys.size(); ++i) {
Expand Down Expand Up @@ -660,7 +657,7 @@ class GroupByNode : public ExecNode {

const std::vector<int> key_field_ids_;
const std::vector<int> agg_src_field_ids_;
const std::vector<internal::Aggregate> aggs_;
const std::vector<Aggregate> aggs_;
const std::vector<const HashAggregateKernel*> agg_kernels_;

ThreadIndexer get_thread_index_;
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/arrow/compute/exec/ir_consumer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ Result<Declaration> Convert(const ir::Relation& rel) {
ARROW_ASSIGN_OR_RAISE(auto arg,
Convert(*aggregate->rel()).As<Declaration::Input>());

AggregateNodeOptions opts{{}, {}, {}};
AggregateNodeOptions opts{{}, {}};

if (!aggregate->measures()) return UnexpectedNullField("Aggregate.measures");
for (const ir::Expression* m : *aggregate->measures()) {
Expand All @@ -550,9 +550,8 @@ Result<Declaration> Convert(const ir::Relation& rel) {
"Support for non-FieldRef arguments to Aggregate.measures");
}

opts.aggregates.push_back({call->function_name, nullptr});
opts.targets.push_back(*target);
opts.names.push_back(call->function_name + " " + target->ToString());
opts.aggregates.push_back({call->function_name, nullptr, *target,
call->function_name + " " + target->ToString()});
}

if (!aggregate->groupings()) return UnexpectedNullField("Aggregate.groupings");
Expand Down
51 changes: 20 additions & 31 deletions cpp/src/arrow/compute/exec/ir_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ TEST(Relation, Filter) {
}

TEST(Relation, AggregateSimple) {
ASSERT_THAT(ConvertJSON<ir::Relation>(R"({
ASSERT_THAT(
ConvertJSON<ir::Relation>(R"({
"impl": {
id: {id: 1},
"groupings": [
Expand Down Expand Up @@ -347,28 +348,22 @@ TEST(Relation, AggregateSimple) {
},
"impl_type": "Aggregate"
})"),
ResultWith(Eq(Declaration::Sequence({
{"catalog_source",
CatalogSourceNodeOptions{"tbl", schema({
field("foo", int32()),
field("bar", int64()),
field("baz", float64()),
})},
"0"},
{"aggregate",
AggregateNodeOptions{/*aggregates=*/{
{"sum", nullptr},
{"mean", nullptr},
},
/*targets=*/{1, 2},
/*names=*/
{
"sum FieldRef.FieldPath(1)",
"mean FieldRef.FieldPath(2)",
},
/*keys=*/{0}},
"1"},
}))));
ResultWith(Eq(Declaration::Sequence({
{"catalog_source",
CatalogSourceNodeOptions{"tbl", schema({
field("foo", int32()),
field("bar", int64()),
field("baz", float64()),
})},
"0"},
{"aggregate",
AggregateNodeOptions{/*aggregates=*/{
{"sum", nullptr, 1, "sum FieldRef.FieldPath(1)"},
{"mean", nullptr, 2, "mean FieldRef.FieldPath(2)"},
},
/*keys=*/{0}},
"1"},
}))));
}

TEST(Relation, AggregateWithHaving) {
Expand Down Expand Up @@ -564,14 +559,8 @@ TEST(Relation, AggregateWithHaving) {
{"filter", FilterNodeOptions{less(field_ref(0), literal<int8_t>(3))}, "1"},
{"aggregate",
AggregateNodeOptions{/*aggregates=*/{
{"sum", nullptr},
{"mean", nullptr},
},
/*targets=*/{1, 2},
/*names=*/
{
"sum FieldRef.FieldPath(1)",
"mean FieldRef.FieldPath(2)",
{"sum", nullptr, 1, "sum FieldRef.FieldPath(1)"},
{"mean", nullptr, 2, "mean FieldRef.FieldPath(2)"},
},
/*keys=*/{0}},
"2"},
Expand Down
16 changes: 4 additions & 12 deletions cpp/src/arrow/compute/exec/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,12 @@ class ARROW_EXPORT ProjectNodeOptions : public ExecNodeOptions {
/// \brief Make a node which aggregates input batches, optionally grouped by keys.
class ARROW_EXPORT AggregateNodeOptions : public ExecNodeOptions {
public:
AggregateNodeOptions(std::vector<internal::Aggregate> aggregates,
std::vector<FieldRef> targets, std::vector<std::string> names,
std::vector<FieldRef> keys = {})
: aggregates(std::move(aggregates)),
targets(std::move(targets)),
names(std::move(names)),
keys(std::move(keys)) {}
explicit AggregateNodeOptions(std::vector<Aggregate> aggregates,
std::vector<FieldRef> keys = {})
: aggregates(std::move(aggregates)), keys(std::move(keys)) {}

// aggregations which will be applied to the targeted fields
std::vector<internal::Aggregate> aggregates;
// fields to which aggregations will be applied
std::vector<FieldRef> targets;
// output field names for aggregations
std::vector<std::string> names;
std::vector<Aggregate> aggregates;
// keys by which aggregations will be grouped
std::vector<FieldRef> keys;
};
Expand Down
Loading

0 comments on commit 0010f97

Please sign in to comment.