From fac754c9fc763aba62c33a3ced95b7b5faee7417 Mon Sep 17 00:00:00 2001 From: Richard Wesley <13156216+hawkfish@users.noreply.github.com> Date: Mon, 11 Dec 2023 06:42:00 -0800 Subject: [PATCH 1/8] Internal #861: Aggregation Absorb API Add an optional destructive combine API called "absorb". --- src/common/row_operations/row_aggregate.cpp | 4 ++-- .../aggregate/grouped_aggregate_data.cpp | 4 ++-- .../physical_perfecthash_aggregate.cpp | 2 +- .../physical_ungrouped_aggregate.cpp | 4 ++-- .../physical_plan/plan_aggregate.cpp | 2 +- .../scalar/system/aggregate_export.cpp | 5 ++-- .../vector_operations/aggregate_executor.hpp | 12 ++++++++++ .../duckdb/function/aggregate_function.hpp | 23 +++++++++++++------ 8 files changed, 39 insertions(+), 17 deletions(-) diff --git a/src/common/row_operations/row_aggregate.cpp b/src/common/row_operations/row_aggregate.cpp index 6c89d887006..96dfbb6f55a 100644 --- a/src/common/row_operations/row_aggregate.cpp +++ b/src/common/row_operations/row_aggregate.cpp @@ -81,9 +81,9 @@ void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &la idx_t offset = layout.GetAggrOffset(); for (auto &aggr : layout.GetAggregates()) { - D_ASSERT(aggr.function.combine); + D_ASSERT(aggr.function.absorb); AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.combine(sources, targets, aggr_input_data, count); + aggr.function.absorb(sources, targets, aggr_input_data, count); // Move to the next aggregate states VectorOperations::AddInPlace(sources, aggr.payload_size, count); diff --git a/src/execution/operator/aggregate/grouped_aggregate_data.cpp b/src/execution/operator/aggregate/grouped_aggregate_data.cpp index d0b52a606bd..fe74c8c57e4 100644 --- a/src/execution/operator/aggregate/grouped_aggregate_data.cpp +++ b/src/execution/operator/aggregate/grouped_aggregate_data.cpp @@ -33,7 +33,7 @@ void GroupedAggregateData::InitializeGroupby(vector> grou filter_count++; 
payload_types_filters.push_back(aggr.filter->return_type); } - if (!aggr.function.combine) { + if (!aggr.function.absorb) { throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); } aggregates.push_back(std::move(expr)); @@ -63,7 +63,7 @@ void GroupedAggregateData::InitializeDistinct(const unique_ptr &aggr filter_count++; } } - if (!aggr.function.combine) { + if (!aggr.function.absorb) { throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); } } diff --git a/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp b/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp index fe7e6c46163..5553986f563 100644 --- a/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp @@ -35,7 +35,7 @@ PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(ClientContext &contex bindings.push_back(&aggr); D_ASSERT(!aggr.IsDistinct()); - D_ASSERT(aggr.function.combine); + D_ASSERT(aggr.function.absorb); for (auto &child : aggr.children) { payload_types.push_back(child->return_type); } diff --git a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index 8db0ede314c..032195c5d37 100644 --- a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -343,7 +343,7 @@ SinkCombineResultType PhysicalUngroupedAggregate::Combine(ExecutionContext &cont Vector dest_state(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get()))); AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator); - aggregate.function.combine(source_state, dest_state, aggr_input_data, 1); + aggregate.function.absorb(source_state, dest_state, aggr_input_data, 1); #ifdef DEBUG gstate.state.counts[aggr_idx] += 
lstate.state.counts[aggr_idx]; #endif @@ -541,7 +541,7 @@ void UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() { Vector state_vec(Value::POINTER(CastPointerToValue(state.aggregates[agg_idx].get()))); Vector combined_vec(Value::POINTER(CastPointerToValue(gstate.state.aggregates[agg_idx].get()))); - aggregate.function.combine(state_vec, combined_vec, aggr_input_data, 1); + aggregate.function.absorb(state_vec, combined_vec, aggr_input_data, 1); } D_ASSERT(!gstate.finished); diff --git a/src/execution/physical_plan/plan_aggregate.cpp b/src/execution/physical_plan/plan_aggregate.cpp index 4160783e2ad..4f84d3b727b 100644 --- a/src/execution/physical_plan/plan_aggregate.cpp +++ b/src/execution/physical_plan/plan_aggregate.cpp @@ -141,7 +141,7 @@ static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate } for (auto &expression : op.expressions) { auto &aggregate = expression->Cast(); - if (aggregate.IsDistinct() || !aggregate.function.combine) { + if (aggregate.IsDistinct() || !aggregate.function.absorb) { // distinct aggregates are not supported in perfect hash aggregates return false; } diff --git a/src/function/scalar/system/aggregate_export.cpp b/src/function/scalar/system/aggregate_export.cpp index e71255384f4..38d138845bd 100644 --- a/src/function/scalar/system/aggregate_export.cpp +++ b/src/function/scalar/system/aggregate_export.cpp @@ -172,7 +172,7 @@ static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Ve memcpy(local_state.state_buffer1.get(), state1.GetData(), bind_data.state_size); AggregateInputData aggr_input_data(nullptr, local_state.allocator); - bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1); + bind_data.aggr.absorb(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1); result_ptr[i] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(local_state.state_buffer1.get()), bind_data.state_size); @@ -299,7 +299,7 @@ static 
unique_ptr ExportStateScalarDeserialize(Deserializer &deser unique_ptr ExportAggregateFunction::Bind(unique_ptr child_aggregate) { auto &bound_function = child_aggregate->function; - if (!bound_function.combine) { + if (!bound_function.absorb) { throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.name); } if (bound_function.bind) { @@ -329,6 +329,7 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega bound_function.combine, ExportAggregateFinalize, bound_function.simple_update, /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr, /* can't propagate statistics */ nullptr, nullptr); + export_function.absorb = bound_function.absorb; export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; export_function.serialize = ExportStateAggregateSerialize; export_function.deserialize = ExportStateAggregateDeserialize; diff --git a/src/include/duckdb/common/vector_operations/aggregate_executor.hpp b/src/include/duckdb/common/vector_operations/aggregate_executor.hpp index 312dd065580..d525e63d8a5 100644 --- a/src/include/duckdb/common/vector_operations/aggregate_executor.hpp +++ b/src/include/duckdb/common/vector_operations/aggregate_executor.hpp @@ -340,6 +340,18 @@ class AggregateExecutor { } } + template + static void Absorb(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) { + D_ASSERT(source.GetType().id() == LogicalTypeId::POINTER && target.GetType().id() == LogicalTypeId::POINTER); + // Destructive combine + auto sdata = FlatVector::GetData(source); + auto tdata = FlatVector::GetData(target); + + for (idx_t i = 0; i < count; i++) { + OP::template Absorb(*sdata[i], *tdata[i], aggr_input_data); + } + } + template static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { diff --git a/src/include/duckdb/function/aggregate_function.hpp b/src/include/duckdb/function/aggregate_function.hpp index 
aff5a71bff6..cb3b8877612 100644 --- a/src/include/duckdb/function/aggregate_function.hpp +++ b/src/include/duckdb/function/aggregate_function.hpp @@ -91,9 +91,10 @@ class AggregateFunction : public BaseScalarFunction { aggregate_deserialize_t deserialize = nullptr) : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS, LogicalType(LogicalTypeId::INVALID), null_handling), - state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), - simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), - serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { + state_size(state_size), initialize(initialize), update(update), combine(combine), absorb(combine), + finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor), + statistics(statistics), serialize(serialize), deserialize(deserialize), + order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { } AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, @@ -105,9 +106,10 @@ class AggregateFunction : public BaseScalarFunction { aggregate_deserialize_t deserialize = nullptr) : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS, LogicalType(LogicalTypeId::INVALID)), - state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), - simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), - serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { + state_size(state_size), initialize(initialize), update(update), combine(combine), absorb(combine), + finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor), + statistics(statistics), serialize(serialize), 
deserialize(deserialize), + order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { } AggregateFunction(const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, @@ -139,8 +141,10 @@ class AggregateFunction : public BaseScalarFunction { aggregate_initialize_t initialize; //! The hashed aggregate update state function aggregate_update_t update; - //! The hashed aggregate combine states function + //! The non-destructive hashed aggregate combine states function aggregate_combine_t combine; + //! The (possibly) destructive hashed aggregate combine states function (may be == combine) + aggregate_combine_t absorb; //! The hashed aggregate finalization function aggregate_finalize_t finalize; //! The simple aggregate update function (may be null) @@ -278,6 +282,11 @@ class AggregateFunction : public BaseScalarFunction { AggregateExecutor::Combine(source, target, aggr_input_data, count); } + template + static void StateAbsorb(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) { + AggregateExecutor::Absorb(source, target, aggr_input_data, count); + } + template static void StateFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { From ed6b6e3f38b9e85652b531de8e925e07e1065551 Mon Sep 17 00:00:00 2001 From: Richard Wesley Date: Tue, 12 Dec 2023 14:03:25 -0800 Subject: [PATCH 2/8] Internal #861: Aggregation Absorb API * Implement absorb for LIST and remove custom window hack * Add test for LIST segment trees and verify it doesn't trash anything * Fix a small API wart in the list segment code. 
--- src/common/types/list_segment.cpp | 4 +- src/core_functions/aggregate/nested/list.cpp | 66 +++++++--------- .../duckdb/common/types/list_segment.hpp | 2 +- test/sql/window/test_empty_frames.test | 4 +- test/sql/window/test_list_window.test | 77 +++++++++++++++++++ test/sql/window/test_window_exclude.test | 6 +- 6 files changed, 110 insertions(+), 49 deletions(-) diff --git a/src/common/types/list_segment.cpp b/src/common/types/list_segment.cpp index 6627383af4d..602fb2abb36 100644 --- a/src/common/types/list_segment.cpp +++ b/src/common/types/list_segment.cpp @@ -523,10 +523,8 @@ static void ReadDataFromArraySegment(const ListSegmentFunctions &functions, cons functions.child_functions[0].BuildListVector(linked_child_list, child_vector, child_size); } -void ListSegmentFunctions::BuildListVector(const LinkedList &linked_list, Vector &result, - idx_t &initial_total_count) const { +void ListSegmentFunctions::BuildListVector(const LinkedList &linked_list, Vector &result, idx_t total_count) const { auto &read_data_from_segment = *this; - idx_t total_count = initial_total_count; auto segment = linked_list.first_segment; while (segment) { read_data_from_segment.read_data(read_data_from_segment, segment, result, total_count); diff --git a/src/core_functions/aggregate/nested/list.cpp b/src/core_functions/aggregate/nested/list.cpp index e9a39593c5f..6e2c655625a 100644 --- a/src/core_functions/aggregate/nested/list.cpp +++ b/src/core_functions/aggregate/nested/list.cpp @@ -67,7 +67,7 @@ static void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_d } } -static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &, idx_t count) { +static void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputData &, idx_t count) { UnifiedVectorFormat states_data; states_vector.ToUnifiedFormat(count, states_data); @@ -147,49 +147,32 @@ static void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_d 
ListVector::SetListSize(result, total_len); } -static void ListWindow(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, - const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result, - idx_t rid) { +static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, + idx_t count) { - auto &list_bind_data = aggr_input_data.bind_data->Cast(); - LinkedList linked_list; + UnifiedVectorFormat states_data; + states_vector.ToUnifiedFormat(count, states_data); + auto states_ptr = UnifiedVectorFormat::GetData(states_data); + auto combined_ptr = FlatVector::GetData(combined); - // UPDATE step + auto &list_bind_data = aggr_input_data.bind_data->Cast(); + auto result_type = ListType::GetChildType(list_bind_data.stype); - D_ASSERT(partition.input_count == 1); - // FIXME: We are modifying the window operator's data here - auto &input = const_cast(partition.inputs[0]); + for (idx_t i = 0; i < count; i++) { + auto &source = *states_ptr[states_data.sel->get_index(i)]; + auto &target = *combined_ptr[i]; - // FIXME: we unify more values than necessary (count is frame.end) - const auto count = frames.back().end; + const auto entry_count = source.linked_list.total_capacity; + Vector input(result_type, source.linked_list.total_capacity); + list_bind_data.functions.BuildListVector(source.linked_list, input, 0); - RecursiveUnifiedVectorFormat input_data; - Vector::RecursiveToUnifiedFormat(input, count, input_data); + RecursiveUnifiedVectorFormat input_data; + Vector::RecursiveToUnifiedFormat(input, entry_count, input_data); - for (const auto &frame : frames) { - for (idx_t i = frame.start; i < frame.end; i++) { - list_bind_data.functions.AppendRow(aggr_input_data.allocator, linked_list, input_data, i); + for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) { + list_bind_data.functions.AppendRow(aggr_input_data.allocator, target.linked_list, input_data, entry_idx); } } - - // FINALIZE 
step - - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto result_data = FlatVector::GetData(result); - size_t total_len = ListVector::GetListSize(result); - - // set the length and offset of this list in the result vector - result_data[rid].offset = total_len; - result_data[rid].length = linked_list.total_capacity; - total_len += linked_list.total_capacity; - - // reserve capacity, then copy over the data to the child vector - ListVector::Reserve(result, total_len); - auto &result_child = ListVector::GetEntry(result); - idx_t offset = result_data[rid].offset; - list_bind_data.functions.BuildListVector(linked_list, result_child, offset); - - ListVector::SetListSize(result, total_len); } unique_ptr ListBindFunction(ClientContext &context, AggregateFunction &function, @@ -208,10 +191,13 @@ unique_ptr ListBindFunction(ClientContext &context, AggregateFunct } AggregateFunction ListFun::GetFunction() { - return AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, ListUpdateFunction, - ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, - ListWindow); + auto func = + AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize, + AggregateFunction::StateInitialize, ListUpdateFunction, + ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, nullptr); + func.absorb = ListAbsorbFunction; + + return func; } } // namespace duckdb diff --git a/src/include/duckdb/common/types/list_segment.hpp b/src/include/duckdb/common/types/list_segment.hpp index ea4c2ad89f0..01c9886c5fc 100644 --- a/src/include/duckdb/common/types/list_segment.hpp +++ b/src/include/duckdb/common/types/list_segment.hpp @@ -51,7 +51,7 @@ struct ListSegmentFunctions { void AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) const; - void BuildListVector(const LinkedList 
&linked_list, Vector &result, idx_t &initial_total_count) const; + void BuildListVector(const LinkedList &linked_list, Vector &result, idx_t initial_total_count) const; }; void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType &type); diff --git a/test/sql/window/test_empty_frames.test b/test/sql/window/test_empty_frames.test index 537c248d47f..250f007245c 100644 --- a/test/sql/window/test_empty_frames.test +++ b/test/sql/window/test_empty_frames.test @@ -105,9 +105,9 @@ SELECT id, ${agg}(id) OVER (PARTITION BY ch ORDER BY id ROWS BETWEEN 1 FOLLOWING FROM t1 ORDER BY 1; ---- -1 [] +1 NULL 2 [NULL] -NULL [] +NULL NULL endloop diff --git a/test/sql/window/test_list_window.test b/test/sql/window/test_list_window.test index a335aeafff8..d39b8cb8184 100644 --- a/test/sql/window/test_list_window.test +++ b/test/sql/window/test_list_window.test @@ -48,3 +48,80 @@ SELECT FIRST(LIST_EXTRACT(l, 3)) FROM list_window GROUP BY g ORDER BY g; NULL NULL NULL + +statement ok +create table list_combine_test as + select range%3 j, + range::varchar AS s, + case when range%3=0 then '-' else '|' end sep + from range(1, 65) + +query III +select j, s, list(s) over (partition by j order by s) +from list_combine_test +order by j, s; +---- +0 12 [12] +0 15 [12, 15] +0 18 [12, 15, 18] +0 21 [12, 15, 18, 21] +0 24 [12, 15, 18, 21, 24] +0 27 [12, 15, 18, 21, 24, 27] +0 3 [12, 15, 18, 21, 24, 27, 3] +0 30 [12, 15, 18, 21, 24, 27, 3, 30] +0 33 [12, 15, 18, 21, 24, 27, 3, 30, 33] +0 36 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36] +0 39 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39] +0 42 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42] +0 45 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42, 45] +0 48 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42, 45, 48] +0 51 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42, 45, 48, 51] +0 54 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42, 45, 48, 51, 54] +0 57 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57] +0 6 [12, 15, 18, 21, 
24, 27, 3, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 6] +0 60 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 6, 60] +0 63 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 6, 60, 63] +0 9 [12, 15, 18, 21, 24, 27, 3, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 6, 60, 63, 9] +1 1 [1] +1 10 [1, 10] +1 13 [1, 10, 13] +1 16 [1, 10, 13, 16] +1 19 [1, 10, 13, 16, 19] +1 22 [1, 10, 13, 16, 19, 22] +1 25 [1, 10, 13, 16, 19, 22, 25] +1 28 [1, 10, 13, 16, 19, 22, 25, 28] +1 31 [1, 10, 13, 16, 19, 22, 25, 28, 31] +1 34 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34] +1 37 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37] +1 4 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4] +1 40 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40] +1 43 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43] +1 46 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43, 46] +1 49 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43, 46, 49] +1 52 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43, 46, 49, 52] +1 55 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43, 46, 49, 52, 55] +1 58 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43, 46, 49, 52, 55, 58] +1 61 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43, 46, 49, 52, 55, 58, 61] +1 64 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43, 46, 49, 52, 55, 58, 61, 64] +1 7 [1, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 4, 40, 43, 46, 49, 52, 55, 58, 61, 64, 7] +2 11 [11] +2 14 [11, 14] +2 17 [11, 14, 17] +2 2 [11, 14, 17, 2] +2 20 [11, 14, 17, 2, 20] +2 23 [11, 14, 17, 2, 20, 23] +2 26 [11, 14, 17, 2, 20, 23, 26] +2 29 [11, 14, 17, 2, 20, 23, 26, 29] +2 32 [11, 14, 17, 2, 20, 23, 26, 29, 32] +2 35 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35] +2 38 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38] +2 41 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41] +2 44 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44] +2 47 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47] +2 5 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44, 
47, 5] +2 50 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47, 5, 50] +2 53 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47, 5, 50, 53] +2 56 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47, 5, 50, 53, 56] +2 59 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47, 5, 50, 53, 56, 59] +2 62 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47, 5, 50, 53, 56, 59, 62] +2 8 [11, 14, 17, 2, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47, 5, 50, 53, 56, 59, 62, 8] diff --git a/test/sql/window/test_window_exclude.test b/test/sql/window/test_window_exclude.test index 66d9ea434f4..90417317d38 100644 --- a/test/sql/window/test_window_exclude.test +++ b/test/sql/window/test_window_exclude.test @@ -531,7 +531,7 @@ FROM ( WINDOW w AS (ORDER BY i ROWS UNBOUNDED PRECEDING EXCLUDE CURRENT ROW) ORDER BY i; ---- -1 [] +1 NULL 1 [1] 2 [1, 1] 2 [1, 1, 2] @@ -553,8 +553,8 @@ FROM ( WINDOW w AS (ORDER BY i ROWS UNBOUNDED PRECEDING EXCLUDE GROUP) ORDER BY i; ---- -1 [] -1 [] +1 NULL +1 NULL 2 [1, 1] 2 [1, 1] 3 [1, 1, 2, 2] From cfe756d2ca1af8af94bdef566bbd6369766325aa Mon Sep 17 00:00:00 2001 From: Richard Wesley <13156216+hawkfish@users.noreply.github.com> Date: Fri, 15 Dec 2023 07:04:42 -0800 Subject: [PATCH 3/8] Internal #861: Aggregation Absorb API PR feedback: * Change header variable name * Document new API in the README. --- src/core_functions/aggregate/README.md | 26 ++++++++++++++++++- .../duckdb/common/types/list_segment.hpp | 2 +- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/core_functions/aggregate/README.md b/src/core_functions/aggregate/README.md index f600e2ce338..8544a2ea0ec 100644 --- a/src/core_functions/aggregate/README.md +++ b/src/core_functions/aggregate/README.md @@ -25,6 +25,7 @@ Unlike simple scalar functions, there are several of these: | `update` | Accumulate the arguments into the corresponding `State` | X | | `simple_update` | Accumulate the arguments into a single `State`. 
| | | `combine` | Merge one `State` into another | | +| `absorb` | Destructively merge one `State` into another | | | `finalize` | Convert a `State` into a final value. | X | | `window` | Compute a windowed aggregate value from the inputs and frame bounds | | | `bind` | Modify the binding of the aggregate | | @@ -129,11 +130,34 @@ Combine(const State& source, State &target, AggregateInputData &info) Note that the `sources` should _not_ be modified for efficiency because the caller may be using them for multiple operations(e.g., window segment trees). -If you wish to combine destructively, you _must_ define a `window` function. +If you wish to combine destructively, you _must_ define an `absorb` function. The `combine` operation is optional, but it is needed for multi-threaded aggregation. If it is not provided, then _all_ aggregate functions in the grouping must be computed on a single thread. +### Absorb + +```cpp +absorb(Vector &sources, Vector &targets, AggregateInputData &info, idx_t count) +``` + +Merges the source states into the corresponding target states. +If you are using template generators, +the generator is `StateAbsorb` and the method it wraps is: + +```cpp +Absorb(State& source, State &target, AggregateInputData &info) +``` +Absorb should be defined when the aggregate can move data more efficiently than copying it. +`LIST` is an example, where the internal linked list data structures can be moved from the source state into the target state instead of being copied. + +`absorb` is optional and defaults to `combine`, but it is called in situations +where destructively moving the source data is allowed because the caller no longer needs the source +(typically for `GROUP BY` operations). +The source still needs to allow the destructor (if any) to be called, +and a separate, non-destructive `combine` operation _must_ be defined for use by +windowing accelerators.
+ ### Finalize ```cpp diff --git a/src/include/duckdb/common/types/list_segment.hpp b/src/include/duckdb/common/types/list_segment.hpp index 01c9886c5fc..79f359faff1 100644 --- a/src/include/duckdb/common/types/list_segment.hpp +++ b/src/include/duckdb/common/types/list_segment.hpp @@ -51,7 +51,7 @@ struct ListSegmentFunctions { void AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) const; - void BuildListVector(const LinkedList &linked_list, Vector &result, idx_t initial_total_count) const; + void BuildListVector(const LinkedList &linked_list, Vector &result, idx_t total_count) const; }; void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType &type); From 3e991b596b8af189aa834502b087b0c72492161d Mon Sep 17 00:00:00 2001 From: Richard Wesley <13156216+hawkfish@users.noreply.github.com> Date: Wed, 20 Dec 2023 14:53:23 -0800 Subject: [PATCH 4/8] Internal #861: Aggregation Absorb API First refactor of segment trees for order-sensitive aggregates. --- src/execution/window_segment_tree.cpp | 66 +++++++++++++++++++-------- 1 file changed, 47 insertions(+), 19 deletions(-) diff --git a/src/execution/window_segment_tree.cpp b/src/execution/window_segment_tree.cpp index d6fcb059def..9f7f860f451 100644 --- a/src/execution/window_segment_tree.cpp +++ b/src/execution/window_segment_tree.cpp @@ -653,7 +653,7 @@ class WindowSegmentTreePart { void ExtractFrame(idx_t begin, idx_t end, data_ptr_t current_state); void WindowSegmentValue(const WindowSegmentTree &tree, idx_t l_idx, idx_t begin, idx_t end, data_ptr_t current_state); - //! optionally writes result and calls destructors + //! 
Writes result and calls destructors void Finalize(Vector &result, idx_t count); void Combine(WindowSegmentTreePart &other, idx_t count); @@ -661,6 +661,15 @@ class WindowSegmentTreePart { void Evaluate(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, Vector &result, idx_t count, idx_t row_idx, FramePart frame_part); +protected: + //! Initialises the accumulation state vector (statef) + void Initialize(idx_t count); + //! Accumulate upper tree levels + void EvaluateUpperLevels(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, idx_t count, + idx_t row_idx, FramePart frame_part); + void EvaluateLeaves(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, idx_t count, + idx_t row_idx, FramePart frame_part); + public: //! Allocator for aggregates ArenaAllocator &allocator; @@ -898,20 +907,23 @@ void WindowSegmentTree::Evaluate(WindowAggregatorState &lstate, const DataChunk part.Finalize(result, count); } -void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, - Vector &result, idx_t count, idx_t row_idx, FramePart frame_part) { +void WindowSegmentTreePart::Initialize(idx_t count) { + auto fdata = FlatVector::GetData(statef); + for (idx_t rid = 0; rid < count; ++rid) { + auto state_ptr = fdata[rid]; + aggr.function.initialize(state_ptr); + } +} - const auto cant_combine = (!aggr.function.combine || !tree.UseCombineAPI()); +void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, + idx_t count, idx_t row_idx, FramePart frame_part) { auto fdata = FlatVector::GetData(statef); const auto exclude_mode = tree.exclude_mode; const bool begin_on_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::CURRENT_ROW; const bool end_on_curr_row = frame_part == FramePart::LEFT && exclude_mode == WindowExcludeMode::CURRENT_ROW; - // with EXCLUDE TIES, in addition to the frame part right of 
the peer group's end, we also need to consider the - // current row - const bool add_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::TIES; + const bool can_reuse = aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT; - // First pass: aggregate the segment tree nodes // Share adjacent identical states // We do this first because we want to share only tree aggregations idx_t prev_begin = 1; @@ -921,12 +933,6 @@ void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t data_ptr_t prev_state = nullptr; for (idx_t rid = 0, cur_row = row_idx; rid < count; ++rid, ++cur_row) { auto state_ptr = fdata[rid]; - aggr.function.initialize(state_ptr); - - if (cant_combine) { - // Make sure we initialise all states - continue; - } auto begin = begin_on_curr_row ? cur_row + 1 : begins[rid]; auto end = end_on_curr_row ? cur_row : ends[rid]; @@ -949,7 +955,7 @@ void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t break; } - if (l_idx == 1) { + if (can_reuse && l_idx == 1) { prev_state = state_ptr; prev_begin = begin; prev_end = end; @@ -979,9 +985,19 @@ void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t } } FlushStates(true); +} + +void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, + idx_t count, idx_t row_idx, FramePart frame_part) { + auto fdata = FlatVector::GetData(statef); + + const auto exclude_mode = tree.exclude_mode; + const bool begin_on_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::CURRENT_ROW; + const bool end_on_curr_row = frame_part == FramePart::LEFT && exclude_mode == WindowExcludeMode::CURRENT_ROW; + // with EXCLUDE TIES, in addition to the frame part right of the peer group's end, we also need to consider the + // current row + const bool add_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::TIES; - // Second 
pass: aggregate the ragged leaves - // (or everything if we can't combine) for (idx_t rid = 0, cur_row = row_idx; rid < count; ++rid, ++cur_row) { auto state_ptr = fdata[rid]; @@ -994,10 +1010,9 @@ void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t continue; } - // Aggregate everything at once if we can't combine states idx_t parent_begin = begin / tree.TREE_FANOUT; idx_t parent_end = end / tree.TREE_FANOUT; - if (parent_begin == parent_end || cant_combine) { + if (parent_begin == parent_end) { WindowSegmentValue(tree, 0, begin, end, state_ptr); continue; } @@ -1015,6 +1030,19 @@ void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t FlushStates(false); } +void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, + Vector &result, idx_t count, idx_t row_idx, FramePart frame_part) { + D_ASSERT(aggr.function.combine && tree.UseCombineAPI()); + + Initialize(count); + + // First pass: aggregate the segment tree nodes + EvaluateUpperLevels(tree, begins, ends, count, row_idx, frame_part); + + // Second pass: aggregate the ragged leaves + EvaluateLeaves(tree, begins, ends, count, row_idx, frame_part); +} + //===--------------------------------------------------------------------===// // WindowDistinctAggregator //===--------------------------------------------------------------------===// From 08d282a2cf89d1ea5473f1426040e6896fadc105 Mon Sep 17 00:00:00 2001 From: Richard Wesley <13156216+hawkfish@users.noreply.github.com> Date: Thu, 21 Dec 2023 09:49:34 -0800 Subject: [PATCH 5/8] Internal #861: Aggregation Absorb API Handle order-sensitive aggregates segment trees. 
--- src/execution/window_segment_tree.cpp | 100 ++++++++++++++++++-------- 1 file changed, 72 insertions(+), 28 deletions(-) diff --git a/src/execution/window_segment_tree.cpp b/src/execution/window_segment_tree.cpp index 9f7f860f451..72fd812f5bc 100644 --- a/src/execution/window_segment_tree.cpp +++ b/src/execution/window_segment_tree.cpp @@ -639,6 +639,9 @@ WindowSegmentTree::~WindowSegmentTree() { class WindowSegmentTreePart { public: + //! Right side nodes need to be cached and processed in reverse order + using RightEntry = std::pair; + enum FramePart : uint8_t { FULL = 0, LEFT = 1, RIGHT = 2 }; WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr, DataChunk &inputs, @@ -668,14 +671,16 @@ class WindowSegmentTreePart { void EvaluateUpperLevels(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, idx_t count, idx_t row_idx, FramePart frame_part); void EvaluateLeaves(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, idx_t count, - idx_t row_idx, FramePart frame_part); + idx_t row_idx, FramePart frame_part, FramePart leaf_part); public: //! Allocator for aggregates ArenaAllocator &allocator; //! The aggregate function const AggregateObject &aggr; - //! The aggregate function + //! Order insensitive aggregate (we can optimise internal combines) + const bool order_insensitive; + //! The partition arguments DataChunk &inputs; //! The filtered rows in inputs const ValidityMask &filter_mask; @@ -695,6 +700,8 @@ class WindowSegmentTreePart { Vector statef; //! Count of buffered values idx_t flush_count; + //! 
Cache of right side tree ranges for ordered aggregates + vector right_stack; }; class WindowSegmentTreeState : public WindowAggregatorState { @@ -717,9 +724,10 @@ class WindowSegmentTreeState : public WindowAggregatorState { WindowSegmentTreePart::WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr, DataChunk &inputs, const ValidityMask &filter_mask) - : allocator(allocator), aggr(aggr), inputs(inputs), filter_mask(filter_mask), - state_size(aggr.function.state_size()), state(state_size * STANDARD_VECTOR_SIZE), statep(LogicalType::POINTER), - statel(LogicalType::POINTER), statef(LogicalType::POINTER), flush_count(0) { + : allocator(allocator), aggr(aggr), + order_insensitive(aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT), inputs(inputs), + filter_mask(filter_mask), state_size(aggr.function.state_size()), state(state_size * STANDARD_VECTOR_SIZE), + statep(LogicalType::POINTER), statel(LogicalType::POINTER), statef(LogicalType::POINTER), flush_count(0) { if (inputs.ColumnCount() > 0) { leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); filter_sel.Initialize(); @@ -907,6 +915,26 @@ void WindowSegmentTree::Evaluate(WindowAggregatorState &lstate, const DataChunk part.Finalize(result, count); } +void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, + Vector &result, idx_t count, idx_t row_idx, FramePart frame_part) { + D_ASSERT(aggr.function.combine && tree.UseCombineAPI()); + + Initialize(count); + + if (order_insensitive) { + // First pass: aggregate the segment tree nodes with sharing + EvaluateUpperLevels(tree, begins, ends, count, row_idx, frame_part); + + // Second pass: aggregate the ragged leaves + EvaluateLeaves(tree, begins, ends, count, row_idx, frame_part, FramePart::FULL); + } else { + // Evaluate leaves in order + EvaluateLeaves(tree, begins, ends, count, row_idx, frame_part, FramePart::LEFT); + EvaluateUpperLevels(tree, begins, 
ends, count, row_idx, frame_part); + EvaluateLeaves(tree, begins, ends, count, row_idx, frame_part, FramePart::RIGHT); + } +} + void WindowSegmentTreePart::Initialize(idx_t count) { auto fdata = FlatVector::GetData(statef); for (idx_t rid = 0; rid < count; ++rid) { @@ -922,7 +950,9 @@ void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTree &tree, c const auto exclude_mode = tree.exclude_mode; const bool begin_on_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::CURRENT_ROW; const bool end_on_curr_row = frame_part == FramePart::LEFT && exclude_mode == WindowExcludeMode::CURRENT_ROW; - const bool can_reuse = aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT; + + const auto max_level = tree.levels_flat_start.size() + 1; + right_stack.resize(max_level, {0, 0}); // Share adjacent identical states // We do this first because we want to share only tree aggregations @@ -942,7 +972,8 @@ void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTree &tree, c // Skip level 0 idx_t l_idx = 0; - for (; l_idx < tree.levels_flat_start.size() + 1; l_idx++) { + idx_t right_max = 0; + for (; l_idx < max_level; l_idx++) { idx_t parent_begin = begin / tree.TREE_FANOUT; idx_t parent_end = end / tree.TREE_FANOUT; if (prev_state && l_idx == 1 && begin == prev_begin && end == prev_end) { @@ -955,7 +986,7 @@ void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTree &tree, c break; } - if (can_reuse && l_idx == 1) { + if (order_insensitive && l_idx == 1) { prev_state = state_ptr; prev_begin = begin; prev_end = end; @@ -977,26 +1008,51 @@ void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTree &tree, c idx_t group_end = parent_end * tree.TREE_FANOUT; if (end != group_end) { if (l_idx) { - WindowSegmentValue(tree, l_idx, group_end, end, state_ptr); + if (order_insensitive) { + WindowSegmentValue(tree, l_idx, group_end, end, state_ptr); + } else { + right_stack[l_idx] = {group_end, 
end}; + right_max = l_idx; + } } } begin = parent_begin; end = parent_end; } + + // Flush the right side values from left to right for order-sensitive aggregates + // As we go up the tree, the right side ranges move left, + // so we just cache them in a fixed size, preallocated array. + // Then we can just reverse scan the array and append the cached ranges. + for (l_idx = right_max; l_idx > 0; --l_idx) { + auto &right_entry = right_stack[l_idx]; + const auto group_end = right_entry.first; + const auto end = right_entry.second; + if (end) { + WindowSegmentValue(tree, l_idx, group_end, end, state_ptr); + right_entry = {0, 0}; + } + } } FlushStates(true); } void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, - idx_t count, idx_t row_idx, FramePart frame_part) { + idx_t count, idx_t row_idx, FramePart frame_part, FramePart leaf_part) { + auto fdata = FlatVector::GetData(statef); + // For order-sensitive aggregates, we have to process the ragged leaves in two pieces. + // The left sides have to be added before the main tree, followed by the ragged right sides. + // The current row is the leftmost value of the right hand side. 
+ const bool compute_left = leaf_part != FramePart::RIGHT; + const bool compute_right = leaf_part != FramePart::LEFT; const auto exclude_mode = tree.exclude_mode; const bool begin_on_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::CURRENT_ROW; const bool end_on_curr_row = frame_part == FramePart::LEFT && exclude_mode == WindowExcludeMode::CURRENT_ROW; // with EXCLUDE TIES, in addition to the frame part right of the peer group's end, we also need to consider the // current row - const bool add_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::TIES; + const bool add_curr_row = compute_left && frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::TIES; for (idx_t rid = 0, cur_row = row_idx; rid < count; ++rid, ++cur_row) { auto state_ptr = fdata[rid]; @@ -1013,36 +1069,24 @@ void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTree &tree, const idx_t parent_begin = begin / tree.TREE_FANOUT; idx_t parent_end = end / tree.TREE_FANOUT; if (parent_begin == parent_end) { - WindowSegmentValue(tree, 0, begin, end, state_ptr); + if (compute_left) { + WindowSegmentValue(tree, 0, begin, end, state_ptr); + } continue; } idx_t group_begin = parent_begin * tree.TREE_FANOUT; - if (begin != group_begin) { + if (begin != group_begin && compute_left) { WindowSegmentValue(tree, 0, begin, group_begin + tree.TREE_FANOUT, state_ptr); - parent_begin++; } idx_t group_end = parent_end * tree.TREE_FANOUT; - if (end != group_end) { + if (end != group_end && compute_right) { WindowSegmentValue(tree, 0, group_end, end, state_ptr); } } FlushStates(false); } -void WindowSegmentTreePart::Evaluate(const WindowSegmentTree &tree, const idx_t *begins, const idx_t *ends, - Vector &result, idx_t count, idx_t row_idx, FramePart frame_part) { - D_ASSERT(aggr.function.combine && tree.UseCombineAPI()); - - Initialize(count); - - // First pass: aggregate the segment tree nodes - EvaluateUpperLevels(tree, begins, ends, count, 
row_idx, frame_part); - - // Second pass: aggregate the ragged leaves - EvaluateLeaves(tree, begins, ends, count, row_idx, frame_part); -} - //===--------------------------------------------------------------------===// // WindowDistinctAggregator //===--------------------------------------------------------------------===// From 4700e9adad76e9f3866c38034c590e72f9700aa7 Mon Sep 17 00:00:00 2001 From: Richard Wesley <13156216+hawkfish@users.noreply.github.com> Date: Mon, 8 Jan 2024 14:21:37 -0800 Subject: [PATCH 6/8] Internal #861: Aggregation Absorb Functionality Use an enum in AggregateInputData instead of creating a new API. --- src/common/row_operations/row_aggregate.cpp | 7 +++--- src/core_functions/aggregate/nested/list.cpp | 11 +++++++-- .../aggregate/grouped_aggregate_data.cpp | 4 ++-- .../physical_perfecthash_aggregate.cpp | 2 +- .../physical_ungrouped_aggregate.cpp | 10 ++++---- .../physical_plan/plan_aggregate.cpp | 2 +- .../scalar/system/aggregate_export.cpp | 7 +++--- .../duckdb/function/aggregate_function.hpp | 23 ++++++------------- .../duckdb/function/aggregate_state.hpp | 8 +++++-- 9 files changed, 39 insertions(+), 35 deletions(-) diff --git a/src/common/row_operations/row_aggregate.cpp b/src/common/row_operations/row_aggregate.cpp index 96dfbb6f55a..f6e9e6cbb37 100644 --- a/src/common/row_operations/row_aggregate.cpp +++ b/src/common/row_operations/row_aggregate.cpp @@ -81,9 +81,10 @@ void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &la idx_t offset = layout.GetAggrOffset(); for (auto &aggr : layout.GetAggregates()) { - D_ASSERT(aggr.function.absorb); - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.absorb(sources, targets, aggr_input_data, count); + D_ASSERT(aggr.function.combine); + AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator, + AggregateCombineType::ALLOW_DESTRUCTIVE); + aggr.function.combine(sources, targets, aggr_input_data, count); // 
Move to the next aggregate states VectorOperations::AddInPlace(sources, aggr.payload_size, count); diff --git a/src/core_functions/aggregate/nested/list.cpp b/src/core_functions/aggregate/nested/list.cpp index 6e2c655625a..1895cd27bd5 100644 --- a/src/core_functions/aggregate/nested/list.cpp +++ b/src/core_functions/aggregate/nested/list.cpp @@ -67,7 +67,9 @@ static void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_d } } -static void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputData &, idx_t count) { +static void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, + idx_t count) { + D_ASSERT(aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE); UnifiedVectorFormat states_data; states_vector.ToUnifiedFormat(count, states_data); @@ -150,6 +152,12 @@ static void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_d static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, idx_t count) { + // Can we use destructive combining? 
+ if (aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE) { + ListAbsorbFunction(states_vector, combined, aggr_input_data, count); + return; + } + UnifiedVectorFormat states_data; states_vector.ToUnifiedFormat(count, states_data); auto states_ptr = UnifiedVectorFormat::GetData(states_data); @@ -195,7 +203,6 @@ AggregateFunction ListFun::GetFunction() { AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize, AggregateFunction::StateInitialize, ListUpdateFunction, ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, nullptr); - func.absorb = ListAbsorbFunction; return func; } diff --git a/src/execution/operator/aggregate/grouped_aggregate_data.cpp b/src/execution/operator/aggregate/grouped_aggregate_data.cpp index fe74c8c57e4..d0b52a606bd 100644 --- a/src/execution/operator/aggregate/grouped_aggregate_data.cpp +++ b/src/execution/operator/aggregate/grouped_aggregate_data.cpp @@ -33,7 +33,7 @@ void GroupedAggregateData::InitializeGroupby(vector> grou filter_count++; payload_types_filters.push_back(aggr.filter->return_type); } - if (!aggr.function.absorb) { + if (!aggr.function.combine) { throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); } aggregates.push_back(std::move(expr)); @@ -63,7 +63,7 @@ void GroupedAggregateData::InitializeDistinct(const unique_ptr &aggr filter_count++; } } - if (!aggr.function.absorb) { + if (!aggr.function.combine) { throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); } } diff --git a/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp b/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp index 5553986f563..fe7e6c46163 100644 --- a/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp @@ -35,7 +35,7 @@ 
PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(ClientContext &contex bindings.push_back(&aggr); D_ASSERT(!aggr.IsDistinct()); - D_ASSERT(aggr.function.absorb); + D_ASSERT(aggr.function.combine); for (auto &child : aggr.children) { payload_types.push_back(child->return_type); } diff --git a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp index 4c4fa750f77..73dcc7c47a9 100644 --- a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp @@ -354,8 +354,9 @@ SinkCombineResultType PhysicalUngroupedAggregate::Combine(ExecutionContext &cont Vector source_state(Value::POINTER(CastPointerToValue(lstate.state.aggregates[aggr_idx].get()))); Vector dest_state(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get()))); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator); - aggregate.function.absorb(source_state, dest_state, aggr_input_data, 1); + AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator, + AggregateCombineType::ALLOW_DESTRUCTIVE); + aggregate.function.combine(source_state, dest_state, aggr_input_data, 1); #ifdef DEBUG gstate.state.counts[aggr_idx] += lstate.state.counts[aggr_idx]; #endif @@ -548,11 +549,12 @@ void UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() { } auto &aggregate = aggregates[agg_idx]->Cast(); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); + AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, + AggregateCombineType::ALLOW_DESTRUCTIVE); Vector state_vec(Value::POINTER(CastPointerToValue(state.aggregates[agg_idx].get()))); Vector combined_vec(Value::POINTER(CastPointerToValue(gstate.state.aggregates[agg_idx].get()))); - aggregate.function.absorb(state_vec, combined_vec, aggr_input_data, 1); + aggregate.function.combine(state_vec, 
combined_vec, aggr_input_data, 1); } D_ASSERT(!gstate.finished); diff --git a/src/execution/physical_plan/plan_aggregate.cpp b/src/execution/physical_plan/plan_aggregate.cpp index 4f84d3b727b..4160783e2ad 100644 --- a/src/execution/physical_plan/plan_aggregate.cpp +++ b/src/execution/physical_plan/plan_aggregate.cpp @@ -141,7 +141,7 @@ static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate } for (auto &expression : op.expressions) { auto &aggregate = expression->Cast(); - if (aggregate.IsDistinct() || !aggregate.function.absorb) { + if (aggregate.IsDistinct() || !aggregate.function.combine) { // distinct aggregates are not supported in perfect hash aggregates return false; } diff --git a/src/function/scalar/system/aggregate_export.cpp b/src/function/scalar/system/aggregate_export.cpp index 38d138845bd..e6adcee0893 100644 --- a/src/function/scalar/system/aggregate_export.cpp +++ b/src/function/scalar/system/aggregate_export.cpp @@ -171,8 +171,8 @@ static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Ve memcpy(local_state.state_buffer0.get(), state0.GetData(), bind_data.state_size); memcpy(local_state.state_buffer1.get(), state1.GetData(), bind_data.state_size); - AggregateInputData aggr_input_data(nullptr, local_state.allocator); - bind_data.aggr.absorb(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1); + AggregateInputData aggr_input_data(nullptr, local_state.allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); + bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1); result_ptr[i] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(local_state.state_buffer1.get()), bind_data.state_size); @@ -299,7 +299,7 @@ static unique_ptr ExportStateScalarDeserialize(Deserializer &deser unique_ptr ExportAggregateFunction::Bind(unique_ptr child_aggregate) { auto &bound_function = child_aggregate->function; - if (!bound_function.absorb) { + if 
(!bound_function.combine) { throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.name); } if (bound_function.bind) { @@ -329,7 +329,6 @@ ExportAggregateFunction::Bind(unique_ptr child_aggrega bound_function.combine, ExportAggregateFinalize, bound_function.simple_update, /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr, /* can't propagate statistics */ nullptr, nullptr); - export_function.absorb = bound_function.absorb; export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; export_function.serialize = ExportStateAggregateSerialize; export_function.deserialize = ExportStateAggregateDeserialize; diff --git a/src/include/duckdb/function/aggregate_function.hpp b/src/include/duckdb/function/aggregate_function.hpp index cb3b8877612..aff5a71bff6 100644 --- a/src/include/duckdb/function/aggregate_function.hpp +++ b/src/include/duckdb/function/aggregate_function.hpp @@ -91,10 +91,9 @@ class AggregateFunction : public BaseScalarFunction { aggregate_deserialize_t deserialize = nullptr) : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS, LogicalType(LogicalTypeId::INVALID), null_handling), - state_size(state_size), initialize(initialize), update(update), combine(combine), absorb(combine), - finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor), - statistics(statistics), serialize(serialize), deserialize(deserialize), - order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { + state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), + simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), + serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { } AggregateFunction(const string &name, const vector &arguments, const LogicalType &return_type, @@ -106,10 +105,9 @@ class 
AggregateFunction : public BaseScalarFunction { aggregate_deserialize_t deserialize = nullptr) : BaseScalarFunction(name, arguments, return_type, FunctionSideEffects::NO_SIDE_EFFECTS, LogicalType(LogicalTypeId::INVALID)), - state_size(state_size), initialize(initialize), update(update), combine(combine), absorb(combine), - finalize(finalize), simple_update(simple_update), window(window), bind(bind), destructor(destructor), - statistics(statistics), serialize(serialize), deserialize(deserialize), - order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { + state_size(state_size), initialize(initialize), update(update), combine(combine), finalize(finalize), + simple_update(simple_update), window(window), bind(bind), destructor(destructor), statistics(statistics), + serialize(serialize), deserialize(deserialize), order_dependent(AggregateOrderDependent::ORDER_DEPENDENT) { } AggregateFunction(const vector &arguments, const LogicalType &return_type, aggregate_size_t state_size, @@ -141,10 +139,8 @@ class AggregateFunction : public BaseScalarFunction { aggregate_initialize_t initialize; //! The hashed aggregate update state function aggregate_update_t update; - //! The non-destructive hashed aggregate combine states function + //! The hashed aggregate combine states function aggregate_combine_t combine; - //! The (possibly) destructive hashed aggregate combine states function (may be == combine) - aggregate_combine_t absorb; //! The hashed aggregate finalization function aggregate_finalize_t finalize; //! 
The simple aggregate update function (may be null) @@ -282,11 +278,6 @@ class AggregateFunction : public BaseScalarFunction { AggregateExecutor::Combine(source, target, aggr_input_data, count); } - template - static void StateAbsorb(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) { - AggregateExecutor::Absorb(source, target, aggr_input_data, count); - } - template static void StateFinalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) { diff --git a/src/include/duckdb/function/aggregate_state.hpp b/src/include/duckdb/function/aggregate_state.hpp index 66b7338c4a5..9b0015d2d66 100644 --- a/src/include/duckdb/function/aggregate_state.hpp +++ b/src/include/duckdb/function/aggregate_state.hpp @@ -17,15 +17,19 @@ namespace duckdb { enum class AggregateType : uint8_t { NON_DISTINCT = 1, DISTINCT = 2 }; //! Whether or not the input order influences the result of the aggregate enum class AggregateOrderDependent : uint8_t { ORDER_DEPENDENT = 1, NOT_ORDER_DEPENDENT = 2 }; +//! 
Whether or not the combiner needs to preserve the source +enum class AggregateCombineType : uint8_t { PRESERVE_INPUT = 1, ALLOW_DESTRUCTIVE = 2 }; class BoundAggregateExpression; struct AggregateInputData { - AggregateInputData(optional_ptr bind_data_p, ArenaAllocator &allocator_p) - : bind_data(bind_data_p), allocator(allocator_p) { + AggregateInputData(optional_ptr bind_data_p, ArenaAllocator &allocator_p, + AggregateCombineType combine_type_p = AggregateCombineType::PRESERVE_INPUT) + : bind_data(bind_data_p), allocator(allocator_p), combine_type(combine_type_p) { } optional_ptr bind_data; ArenaAllocator &allocator; + AggregateCombineType combine_type; }; struct AggregateUnaryInput { From cb321700027ccbdc72e74ba24197410c31893553 Mon Sep 17 00:00:00 2001 From: Richard Wesley <13156216+hawkfish@users.noreply.github.com> Date: Mon, 8 Jan 2024 14:27:55 -0800 Subject: [PATCH 7/8] Internal #861: Aggregation Absorb Functionality Generate AggregateCombineType enum strings... --- src/common/enum_util.cpp | 23 +++++++++++++++++++++++ src/include/duckdb/common/enum_util.hpp | 8 ++++++++ 2 files changed, 31 insertions(+) diff --git a/src/common/enum_util.cpp b/src/common/enum_util.cpp index 650e2ec94e8..dd8cd85560e 100644 --- a/src/common/enum_util.cpp +++ b/src/common/enum_util.cpp @@ -154,6 +154,29 @@ AccessMode EnumUtil::FromString(const char *value) { throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); } +template<> +const char* EnumUtil::ToChars(AggregateCombineType value) { + switch(value) { + case AggregateCombineType::PRESERVE_INPUT: + return "PRESERVE_INPUT"; + case AggregateCombineType::ALLOW_DESTRUCTIVE: + return "ALLOW_DESTRUCTIVE"; + default: + throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value)); + } +} + +template<> +AggregateCombineType EnumUtil::FromString(const char *value) { + if (StringUtil::Equals(value, "PRESERVE_INPUT")) { + return AggregateCombineType::PRESERVE_INPUT; + 
} + if (StringUtil::Equals(value, "ALLOW_DESTRUCTIVE")) { + return AggregateCombineType::ALLOW_DESTRUCTIVE; + } + throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value)); +} + template<> const char* EnumUtil::ToChars(AggregateHandling value) { switch(value) { diff --git a/src/include/duckdb/common/enum_util.hpp b/src/include/duckdb/common/enum_util.hpp index 57b876419c0..60285378970 100644 --- a/src/include/duckdb/common/enum_util.hpp +++ b/src/include/duckdb/common/enum_util.hpp @@ -34,6 +34,8 @@ struct EnumUtil { enum class AccessMode : uint8_t; +enum class AggregateCombineType : uint8_t; + enum class AggregateHandling : uint8_t; enum class AggregateOrderDependent : uint8_t; @@ -308,6 +310,9 @@ enum class WithinCollection : uint8_t; template<> const char* EnumUtil::ToChars(AccessMode value); +template<> +const char* EnumUtil::ToChars(AggregateCombineType value); + template<> const char* EnumUtil::ToChars(AggregateHandling value); @@ -717,6 +722,9 @@ const char* EnumUtil::ToChars(WithinCollection value); template<> AccessMode EnumUtil::FromString(const char *value); +template<> +AggregateCombineType EnumUtil::FromString(const char *value); + template<> AggregateHandling EnumUtil::FromString(const char *value); From 6503d144a473e52f3d0fb393478ce94f134bd5dc Mon Sep 17 00:00:00 2001 From: Richard Wesley <13156216+hawkfish@users.noreply.github.com> Date: Tue, 9 Jan 2024 08:21:46 -0800 Subject: [PATCH 8/8] Internal #861: Aggregation Absorb API Remove last references to the destructive combine API called "absorb". 
--- src/core_functions/aggregate/README.md | 37 ++++--------------- .../vector_operations/aggregate_executor.hpp | 12 ------ 2 files changed, 7 insertions(+), 42 deletions(-) diff --git a/src/core_functions/aggregate/README.md b/src/core_functions/aggregate/README.md index 8544a2ea0ec..8b9fa3477f3 100644 --- a/src/core_functions/aggregate/README.md +++ b/src/core_functions/aggregate/README.md @@ -25,7 +25,6 @@ Unlike simple scalar functions, there are several of these: | `update` | Accumulate the arguments into the corresponding `State` | X | | `simple_update` | Accumulate the arguments into a single `State`. | | | `combine` | Merge one `State` into another | | -| `absorb` | Destructively merge one `State` into another | | | `finalize` | Convert a `State` into a final value. | X | | `window` | Compute a windowed aggregate value from the inputs and frame bounds | | | `bind` | Modify the binding of the aggregate | | @@ -128,36 +127,17 @@ the generator is `StateCombine` and the method it wraps is: Combine(const State& source, State &target, AggregateInputData &info) ``` -Note that the `sources` should _not_ be modified for efficiency because the caller may be using them -for multiple operations(e.g., window segment trees). -If you wish to combine destructively, you _must_ define an `absorb` function. +Note that the `source` should _not_ be modified for efficiency because the caller may be using them +for multiple operations (e.g., window segment trees). + +If you wish to combine destructively, you _must_ check that the `combine_type` member +of the `AggregateInputData` argument is set to `ALLOW_DESTRUCTIVE`. +This is useful when the aggregate can move data more efficiently than copying it. +`LIST` is an example, where the internal linked list data structures can be spliced instead of copied. The `combine` operation is optional, but it is needed for multi-threaded aggregation. 
If it is not provided, then _all_ aggregate functions in the grouping must be computed on a single thread. -### Absorb - -```cpp -absorb(Vector &sources, Vector &targets, AggregateInputData &info, idx_t count) -``` - -Merges the source states into the corresponding target states. -If you are using template generators, -the generator is `StateAbsorb` and the method it wraps is: - -```cpp -Absorb(State& source, State &target, AggregateInputData &info) -``` -Absorb should be defined when the aggregate can move data more efficiently than copying it. -`LIST` is an example, where the internal linked list data structures can be - -`absorb` is optional and defaults to `combine`, but it is called in situations -where destructively moving the source data is allowed because the caller no longer needs the source -(typically for `GROUP BY` operations). -The source still needs to allow the destructor (if any) to be called, -and a separate, non-destructive `combine` operation _must_ to be defined for use by -windowing accelerators. - ### Finalize ```cpp @@ -208,9 +188,6 @@ Window(const ArgType *arg, ValidityMask &filter, ValidityMask &valid, ResultType &result, idx_t rid, idx_tbias) ``` -Defining `window` is also useful if the aggregate wishes to use a destructive `combine` operation. -This may be tricky to implement efficiently. 
- ### Bind ```cpp diff --git a/src/include/duckdb/common/vector_operations/aggregate_executor.hpp b/src/include/duckdb/common/vector_operations/aggregate_executor.hpp index d525e63d8a5..312dd065580 100644 --- a/src/include/duckdb/common/vector_operations/aggregate_executor.hpp +++ b/src/include/duckdb/common/vector_operations/aggregate_executor.hpp @@ -340,18 +340,6 @@ class AggregateExecutor { } } - template - static void Absorb(Vector &source, Vector &target, AggregateInputData &aggr_input_data, idx_t count) { - D_ASSERT(source.GetType().id() == LogicalTypeId::POINTER && target.GetType().id() == LogicalTypeId::POINTER); - // Destructive combine - auto sdata = FlatVector::GetData(source); - auto tdata = FlatVector::GetData(target); - - for (idx_t i = 0; i < count; i++) { - OP::template Absorb(*sdata[i], *tdata[i], aggr_input_data); - } - } - template static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, idx_t offset) {