Skip to content

Commit

Permalink
Internal duckdb#861: Aggregation Absorb API
Browse files Browse the repository at this point in the history
Add an optional destructive combine API called "absorb".
  • Loading branch information
hawkfish committed Dec 11, 2023
1 parent e117c34 commit fac754c
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/common/row_operations/row_aggregate.cpp
Expand Up @@ -81,9 +81,9 @@ void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &la
idx_t offset = layout.GetAggrOffset();

for (auto &aggr : layout.GetAggregates()) {
D_ASSERT(aggr.function.combine);
D_ASSERT(aggr.function.absorb);
AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator);
aggr.function.combine(sources, targets, aggr_input_data, count);
aggr.function.absorb(sources, targets, aggr_input_data, count);

// Move to the next aggregate states
VectorOperations::AddInPlace(sources, aggr.payload_size, count);
Expand Down
4 changes: 2 additions & 2 deletions src/execution/operator/aggregate/grouped_aggregate_data.cpp
Expand Up @@ -33,7 +33,7 @@ void GroupedAggregateData::InitializeGroupby(vector<unique_ptr<Expression>> grou
filter_count++;
payload_types_filters.push_back(aggr.filter->return_type);
}
if (!aggr.function.combine) {
if (!aggr.function.absorb) {
throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name);
}
aggregates.push_back(std::move(expr));
Expand Down Expand Up @@ -63,7 +63,7 @@ void GroupedAggregateData::InitializeDistinct(const unique_ptr<Expression> &aggr
filter_count++;
}
}
if (!aggr.function.combine) {
if (!aggr.function.absorb) {
throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name);
}
}
Expand Down
Expand Up @@ -35,7 +35,7 @@ PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(ClientContext &contex
bindings.push_back(&aggr);

D_ASSERT(!aggr.IsDistinct());
D_ASSERT(aggr.function.combine);
D_ASSERT(aggr.function.absorb);
for (auto &child : aggr.children) {
payload_types.push_back(child->return_type);
}
Expand Down
Expand Up @@ -343,7 +343,7 @@ SinkCombineResultType PhysicalUngroupedAggregate::Combine(ExecutionContext &cont
Vector dest_state(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get())));

AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator);
aggregate.function.combine(source_state, dest_state, aggr_input_data, 1);
aggregate.function.absorb(source_state, dest_state, aggr_input_data, 1);
#ifdef DEBUG
gstate.state.counts[aggr_idx] += lstate.state.counts[aggr_idx];
#endif
Expand Down Expand Up @@ -541,7 +541,7 @@ void UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() {

Vector state_vec(Value::POINTER(CastPointerToValue(state.aggregates[agg_idx].get())));
Vector combined_vec(Value::POINTER(CastPointerToValue(gstate.state.aggregates[agg_idx].get())));
aggregate.function.combine(state_vec, combined_vec, aggr_input_data, 1);
aggregate.function.absorb(state_vec, combined_vec, aggr_input_data, 1);
}

D_ASSERT(!gstate.finished);
Expand Down
2 changes: 1 addition & 1 deletion src/execution/physical_plan/plan_aggregate.cpp
Expand Up @@ -141,7 +141,7 @@ static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate
}
for (auto &expression : op.expressions) {
auto &aggregate = expression->Cast<BoundAggregateExpression>();
if (aggregate.IsDistinct() || !aggregate.function.combine) {
if (aggregate.IsDistinct() || !aggregate.function.absorb) {
// distinct aggregates are not supported in perfect hash aggregates
return false;
}
Expand Down
5 changes: 3 additions & 2 deletions src/function/scalar/system/aggregate_export.cpp
Expand Up @@ -172,7 +172,7 @@ static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Ve
memcpy(local_state.state_buffer1.get(), state1.GetData(), bind_data.state_size);

AggregateInputData aggr_input_data(nullptr, local_state.allocator);
bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1);
bind_data.aggr.absorb(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1);

result_ptr[i] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(local_state.state_buffer1.get()),
bind_data.state_size);
Expand Down Expand Up @@ -299,7 +299,7 @@ static unique_ptr<FunctionData> ExportStateScalarDeserialize(Deserializer &deser
unique_ptr<BoundAggregateExpression>
ExportAggregateFunction::Bind(unique_ptr<BoundAggregateExpression> child_aggregate) {
auto &bound_function = child_aggregate->function;
if (!bound_function.combine) {
if (!bound_function.absorb) {
throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.name);
}
if (bound_function.bind) {
Expand Down Expand Up @@ -329,6 +329,7 @@ ExportAggregateFunction::Bind(unique_ptr<BoundAggregateExpression> child_aggrega
bound_function.combine, ExportAggregateFinalize, bound_function.simple_update,
/* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr,
/* can't propagate statistics */ nullptr, nullptr);
export_function.absorb = bound_function.absorb;
export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
export_function.serialize = ExportStateAggregateSerialize;
export_function.deserialize = ExportStateAggregateDeserialize;
Expand Down
12 changes: 12 additions & 0 deletions src/include/duckdb/common/vector_operations/aggregate_executor.hpp
Expand Up @@ -340,6 +340,18 @@ class AggregateExecutor {
}
}

//! Destructive variant of combining aggregate states: for each of the `count` rows, hands the
//! source state and the target state to OP::Absorb together with the aggregate input data.
//! Both vectors must be of POINTER type and are read through FlatVector::GetData as arrays of
//! STATE_TYPE pointers. What is left in a source state afterwards is decided by OP::Absorb
//! (not visible here) - callers should presumably treat the source states as consumed.
template <class STATE_TYPE, class OP>
static void Absorb(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) {
	D_ASSERT(source.GetType().id() == LogicalTypeId::POINTER && target.GetType().id() == LogicalTypeId::POINTER);

	// Row-aligned arrays of state pointers: entry `row` of each array forms one (source, target) pair
	auto source_states = FlatVector::GetData<STATE_TYPE *>(source);
	auto target_states = FlatVector::GetData<STATE_TYPE *>(target);

	for (idx_t row = 0; row < count; ++row) {
		auto &from_state = *source_states[row];
		auto &into_state = *target_states[row];
		OP::template Absorb<STATE_TYPE, OP>(from_state, into_state, aggr_input_data);
	}
}

template <class STATE_TYPE, class RESULT_TYPE, class OP>
static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
idx_t offset) {
Expand Down
23 changes: 16 additions & 7 deletions src/include/duckdb/function/aggregate_function.hpp
Expand Up @@ -91,9 +91,10 @@ class AggregateFunction : public BaseScalarFunction {
aggregate_deserialize_t deserialize = nullptr)
: BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS,
LogicalType(LogicalTypeId::INVALID), null_handling),
state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize),
simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics),
serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) {
state_size(state_size), initialize(initialize), update(update), combine(combine), absorb(combine),
finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor),
statistics(statistics), serialize(serialize), deserialize(deserialize),
order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) {
}

AggregateFunction(const string &name, const vector<LogicalType> &arguments, const LogicalType &return_type,
Expand All @@ -105,9 +106,10 @@ class AggregateFunction : public BaseScalarFunction {
aggregate_deserialize_t deserialize = nullptr)
: BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS,
LogicalType(LogicalTypeId::INVALID)),
state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize),
simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics),
serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) {
state_size(state_size), initialize(initialize), update(update), combine(combine), absorb(combine),
finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor),
statistics(statistics), serialize(serialize), deserialize(deserialize),
order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) {
}

AggregateFunction(const vector<LogicalType> &arguments, const LogicalType &return_type, aggregate_size_t state_size,
Expand Down Expand Up @@ -139,8 +141,10 @@ class AggregateFunction : public BaseScalarFunction {
aggregate_initialize_t initialize;
//! The hashed aggregate update state function
aggregate_update_t update;
//! The hashed aggregate combine states function
//! The non-destructive hashed aggregate combine states function
aggregate_combine_t combine;
//! The (possibly) destructive hashed aggregate combine states function (may be == combine)
aggregate_combine_t absorb;
//! The hashed aggregate finalization function
aggregate_finalize_t finalize;
//! The simple aggregate update function (may be null)
Expand Down Expand Up @@ -278,6 +282,11 @@ class AggregateFunction : public BaseScalarFunction {
AggregateExecutor::Combine<STATE, OP>(source, target, aggr_input_data, count);
}

//! Absorb entry point for an aggregate with state type STATE and operation OP.
//! Forwards all arguments unchanged to AggregateExecutor::Absorb<STATE, OP>; `source` and `target`
//! are the state vectors and `count` is the number of states passed along.
//! NOTE(review): presumably intended to be stored in an aggregate function's `absorb` slot - confirm at the use site.
template <class STATE, class OP>
static void StateAbsorb(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) {
AggregateExecutor::Absorb<STATE, OP>(source, target, aggr_input_data, count);
}

template <class STATE, class RESULT_TYPE, class OP>
static void StateFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count,
idx_t offset) {
Expand Down

0 comments on commit fac754c

Please sign in to comment.