Skip to content

Commit

Permalink
Merge pull request #9971 from hawkfish/absorb
Browse files Browse the repository at this point in the history
Internal #861: Aggregation Absorb API
  • Loading branch information
Mytherin committed Jan 15, 2024
2 parents 83c60cb + 4960dbe commit 41d90ca
Show file tree
Hide file tree
Showing 12 changed files with 262 additions and 92 deletions.
23 changes: 23 additions & 0 deletions src/common/enum_util.cpp
Expand Up @@ -155,6 +155,29 @@ AccessMode EnumUtil::FromString<AccessMode>(const char *value) {
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
}

template<>
const char* EnumUtil::ToChars<AggregateCombineType>(AggregateCombineType value) {
switch(value) {
case AggregateCombineType::PRESERVE_INPUT:
return "PRESERVE_INPUT";
case AggregateCombineType::ALLOW_DESTRUCTIVE:
return "ALLOW_DESTRUCTIVE";
default:
throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value));
}
}

template<>
AggregateCombineType EnumUtil::FromString<AggregateCombineType>(const char *value) {
if (StringUtil::Equals(value, "PRESERVE_INPUT")) {
return AggregateCombineType::PRESERVE_INPUT;
}
if (StringUtil::Equals(value, "ALLOW_DESTRUCTIVE")) {
return AggregateCombineType::ALLOW_DESTRUCTIVE;
}
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
}

template<>
const char* EnumUtil::ToChars<AggregateHandling>(AggregateHandling value) {
switch(value) {
Expand Down
3 changes: 2 additions & 1 deletion src/common/row_operations/row_aggregate.cpp
Expand Up @@ -82,7 +82,8 @@ void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &la

for (auto &aggr : layout.GetAggregates()) {
D_ASSERT(aggr.function.combine);
AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator);
AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator,
AggregateCombineType::ALLOW_DESTRUCTIVE);
aggr.function.combine(sources, targets, aggr_input_data, count);

// Move to the next aggregate states
Expand Down
4 changes: 1 addition & 3 deletions src/common/types/list_segment.cpp
Expand Up @@ -521,10 +521,8 @@ static void ReadDataFromArraySegment(const ListSegmentFunctions &functions, cons
functions.child_functions[0].BuildListVector(linked_child_list, child_vector, child_size);
}

void ListSegmentFunctions::BuildListVector(const LinkedList &linked_list, Vector &result,
idx_t &initial_total_count) const {
void ListSegmentFunctions::BuildListVector(const LinkedList &linked_list, Vector &result, idx_t total_count) const {
auto &read_data_from_segment = *this;
idx_t total_count = initial_total_count;
auto segment = linked_list.first_segment;
while (segment) {
read_data_from_segment.read_data(read_data_from_segment, segment, result, total_count);
Expand Down
13 changes: 7 additions & 6 deletions src/core_functions/aggregate/README.md
Expand Up @@ -127,9 +127,13 @@ the generator is `StateCombine` and the method it wraps is:
Combine(const State& source, State &target, AggregateInputData &info)
```
Note that the `sources` should _not_ be modified for efficiency because the caller may be using them
for multiple operations(e.g., window segment trees).
If you wish to combine destructively, you _must_ define a `window` function.
Note that the `source` should _not_ be modified for efficiency because the caller may be using them
for multiple operations (e.g., window segment trees).
If you wish to combine destructively, you _must_ check that the `combine_type` member
of the `AggregateInputData` argument is set to `ALLOW_DESTRUCTIVE`.
This is useful when the aggregate can move data more efficiently than copying it.
`LIST` is an example, where the internal linked list data structures can be spliced instead of copied.
The `combine` operation is optional, but it is needed for multi-threaded aggregation.
If it is not provided, then _all_ aggregate functions in the grouping must be computed on a single thread.
Expand Down Expand Up @@ -184,9 +188,6 @@ Window(const ArgType *arg, ValidityMask &filter, ValidityMask &valid,
ResultType &result, idx_t rid, idx_tbias)
```
Defining `window` is also useful if the aggregate wishes to use a destructive `combine` operation.
This may be tricky to implement efficiently.
### Bind
```cpp
Expand Down
82 changes: 33 additions & 49 deletions src/core_functions/aggregate/nested/list.cpp
Expand Up @@ -67,7 +67,9 @@ static void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_d
}
}

static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &, idx_t count) {
static void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data,
idx_t count) {
D_ASSERT(aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE);

UnifiedVectorFormat states_data;
states_vector.ToUnifiedFormat(count, states_data);
Expand Down Expand Up @@ -147,58 +149,38 @@ static void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_d
ListVector::SetListSize(result, total_len);
}

static void ListWindow(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition,
const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result,
idx_t rid) {
static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data,
idx_t count) {

auto &list_bind_data = aggr_input_data.bind_data->Cast<ListBindData>();
LinkedList linked_list;

// UPDATE step

D_ASSERT(partition.input_count == 1);
// FIXME: We are modifying the window operator's data here
auto &input = const_cast<Vector &>(partition.inputs[0]);

// FIXME: we unify more values than necessary (count is frame.end)
const auto count = frames.back().end;

RecursiveUnifiedVectorFormat input_data;
Vector::RecursiveToUnifiedFormat(input, count, input_data);

for (const auto &frame : frames) {
for (idx_t i = frame.start; i < frame.end; i++) {
list_bind_data.functions.AppendRow(aggr_input_data.allocator, linked_list, input_data, i);
}
// Can we use destructive combining?
if (aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE) {
ListAbsorbFunction(states_vector, combined, aggr_input_data, count);
return;
}

// FINALIZE step

D_ASSERT(result.GetType().id() == LogicalTypeId::LIST);
auto result_data = FlatVector::GetData<list_entry_t>(result);
size_t total_len = ListVector::GetListSize(result);
UnifiedVectorFormat states_data;
states_vector.ToUnifiedFormat(count, states_data);
auto states_ptr = UnifiedVectorFormat::GetData<const ListAggState *>(states_data);
auto combined_ptr = FlatVector::GetData<ListAggState *>(combined);

// set the length and offset of this list in the result vector
result_data[rid].offset = total_len;
result_data[rid].length = linked_list.total_capacity;
auto &list_bind_data = aggr_input_data.bind_data->Cast<ListBindData>();
auto result_type = ListType::GetChildType(list_bind_data.stype);

// Empty frames produce NULL to track PG
if (!linked_list.total_capacity) {
auto &mask = FlatVector::Validity(result);
mask.SetInvalid(rid);
return;
}
for (idx_t i = 0; i < count; i++) {
auto &source = *states_ptr[states_data.sel->get_index(i)];
auto &target = *combined_ptr[i];

D_ASSERT(linked_list.total_capacity != 0);
total_len += linked_list.total_capacity;
const auto entry_count = source.linked_list.total_capacity;
Vector input(result_type, source.linked_list.total_capacity);
list_bind_data.functions.BuildListVector(source.linked_list, input, 0);

// reserve capacity, then copy over the data to the child vector
ListVector::Reserve(result, total_len);
auto &result_child = ListVector::GetEntry(result);
idx_t offset = result_data[rid].offset;
list_bind_data.functions.BuildListVector(linked_list, result_child, offset);
RecursiveUnifiedVectorFormat input_data;
Vector::RecursiveToUnifiedFormat(input, entry_count, input_data);

ListVector::SetListSize(result, total_len);
for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) {
list_bind_data.functions.AppendRow(aggr_input_data.allocator, target.linked_list, input_data, entry_idx);
}
}
}

unique_ptr<FunctionData> ListBindFunction(ClientContext &context, AggregateFunction &function,
Expand All @@ -217,10 +199,12 @@ unique_ptr<FunctionData> ListBindFunction(ClientContext &context, AggregateFunct
}

AggregateFunction ListFun::GetFunction() {
return AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize<ListAggState>,
AggregateFunction::StateInitialize<ListAggState, ListFunction>, ListUpdateFunction,
ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr,
ListWindow);
auto func =
AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize<ListAggState>,
AggregateFunction::StateInitialize<ListAggState, ListFunction>, ListUpdateFunction,
ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, nullptr);

return func;
}

} // namespace duckdb
Expand Up @@ -354,7 +354,8 @@ SinkCombineResultType PhysicalUngroupedAggregate::Combine(ExecutionContext &cont
Vector source_state(Value::POINTER(CastPointerToValue(lstate.state.aggregates[aggr_idx].get())));
Vector dest_state(Value::POINTER(CastPointerToValue(gstate.state.aggregates[aggr_idx].get())));

AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator);
AggregateInputData aggr_input_data(aggregate.bind_info.get(), gstate.allocator,
AggregateCombineType::ALLOW_DESTRUCTIVE);
aggregate.function.combine(source_state, dest_state, aggr_input_data, 1);
#ifdef DEBUG
gstate.state.counts[aggr_idx] += lstate.state.counts[aggr_idx];
Expand Down Expand Up @@ -548,7 +549,8 @@ void UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() {
}

auto &aggregate = aggregates[agg_idx]->Cast<BoundAggregateExpression>();
AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator);
AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator,
AggregateCombineType::ALLOW_DESTRUCTIVE);

Vector state_vec(Value::POINTER(CastPointerToValue(state.aggregates[agg_idx].get())));
Vector combined_vec(Value::POINTER(CastPointerToValue(gstate.state.aggregates[agg_idx].get())));
Expand Down

0 comments on commit 41d90ca

Please sign in to comment.