Commit 029fe92

Merge pull request duckdb#9989 from lnkuiper/issue9718
Improve progress bar for Aggregation and limit threads for large data sizes
Mytherin authored and ywelsch committed Mar 19, 2024
1 parent 9a4ee6a commit 029fe92
Showing 7 changed files with 104 additions and 23 deletions.
9 changes: 8 additions & 1 deletion src/execution/aggregate_hashtable.cpp
@@ -512,7 +512,7 @@ void GroupedAggregateHashTable::Combine(GroupedAggregateHashTable &other) {
}
}

-void GroupedAggregateHashTable::Combine(TupleDataCollection &other_data) {
+void GroupedAggregateHashTable::Combine(TupleDataCollection &other_data, optional_ptr<atomic<double>> progress) {
D_ASSERT(other_data.GetLayout().GetAggrWidth() == layout.GetAggrWidth());
D_ASSERT(other_data.GetLayout().GetDataWidth() == layout.GetDataWidth());
D_ASSERT(other_data.GetLayout().GetRowWidth() == layout.GetRowWidth());
@@ -523,6 +523,9 @@ void GroupedAggregateHashTable::Combine(TupleDataCollection &other_data) {

FlushMoveState fm_state(other_data);
RowOperationsState row_state(*aggregate_allocator);

+idx_t chunk_idx = 0;
+const auto chunk_count = other_data.ChunkCount();
while (fm_state.Scan()) {
FindOrCreateGroups(fm_state.groups, fm_state.hashes, fm_state.group_addresses, fm_state.new_groups_sel);
RowOperations::CombineStates(row_state, layout, fm_state.scan_state.chunk_state.row_locations,
@@ -531,6 +534,10 @@ void GroupedAggregateHashTable::Combine(TupleDataCollection &other_data) {
RowOperations::DestroyStates(row_state, layout, fm_state.scan_state.chunk_state.row_locations,
fm_state.groups.size());
}

+if (progress) {
+*progress = double(++chunk_idx) / double(chunk_count);
+}
}

Verify();
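The new Combine overload threads an optional atomic<double> through the per-chunk loop, so another thread can observe combine progress without locking. Below is a minimal standalone sketch of the same pattern (hypothetical names, not DuckDB code; only the atomic handshake is the point):

#include <atomic>
#include <cstddef>
#include <cstdio>

// Worker: combines `chunk_count` chunks, publishing the completed fraction.
static void CombineChunks(std::size_t chunk_count, std::atomic<double> *progress) {
	for (std::size_t chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
		// ... combine one chunk here ...
		if (progress) {
			// atomic store; a monitor thread may load this at any time
			*progress = double(chunk_idx + 1) / double(chunk_count);
		}
	}
}

int main() {
	std::atomic<double> progress(0.0);
	CombineChunks(64, &progress);
	std::printf("progress: %.2f\n", progress.load()); // prints 1.00
	return 0;
}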
17 changes: 14 additions & 3 deletions src/execution/operator/aggregate/physical_hash_aggregate.cpp
@@ -782,13 +782,13 @@ class HashAggregateGlobalSourceState : public GlobalSourceState {
}

auto &ht_state = op.sink_state->Cast<HashAggregateGlobalSinkState>();
-idx_t partitions = 0;
+idx_t threads = 0;
for (size_t sidx = 0; sidx < op.groupings.size(); ++sidx) {
auto &grouping = op.groupings[sidx];
auto &grouping_gstate = ht_state.grouping_states[sidx];
-partitions += grouping.table_data.NumberOfPartitions(*grouping_gstate.table_state);
+threads += grouping.table_data.MaxThreads(*grouping_gstate.table_state);
}
-return MaxValue<idx_t>(1, partitions);
+return MaxValue<idx_t>(1, threads);
}
};

@@ -850,6 +850,17 @@ SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataC
return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT;
}

+double PhysicalHashAggregate::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const {
+auto &sink_gstate = sink_state->Cast<HashAggregateGlobalSinkState>();
+auto &gstate = gstate_p.Cast<HashAggregateGlobalSourceState>();
+double total_progress = 0;
+for (idx_t radix_idx = 0; radix_idx < groupings.size(); radix_idx++) {
+total_progress += groupings[radix_idx].table_data.GetProgress(
+context, *sink_gstate.grouping_states[radix_idx].table_state, *gstate.radix_states[radix_idx]);
+}
+return total_progress / double(groupings.size());
+}

string PhysicalHashAggregate::ParamsToString() const {
string result;
auto &groups = grouped_aggregate_data.groups;
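The GetProgress override above averages the per-grouping progress reported by the radix-partitioned tables. As a worked example with assumed numbers: three groupings reporting 100, 40, and 10 give (100 + 40 + 10) / 3 = 50, so the operator reports itself as half done.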
17 changes: 9 additions & 8 deletions src/execution/physical_plan/plan_create_table.cpp
@@ -1,16 +1,17 @@
#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp"
#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
#include "duckdb/catalog/duck_catalog.hpp"
#include "duckdb/execution/operator/persistent/physical_batch_insert.hpp"
#include "duckdb/execution/operator/persistent/physical_insert.hpp"
#include "duckdb/execution/operator/schema/physical_create_table.hpp"
#include "duckdb/execution/physical_plan_generator.hpp"
#include "duckdb/main/config.hpp"
#include "duckdb/parallel/task_scheduler.hpp"
#include "duckdb/parser/parsed_data/create_table_info.hpp"
#include "duckdb/execution/operator/persistent/physical_insert.hpp"
#include "duckdb/planner/constraints/bound_check_constraint.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/planner/operator/logical_create_table.hpp"
#include "duckdb/main/config.hpp"
#include "duckdb/execution/operator/persistent/physical_batch_insert.hpp"
#include "duckdb/planner/constraints/bound_check_constraint.hpp"
#include "duckdb/parallel/task_scheduler.hpp"
#include "duckdb/catalog/duck_catalog.hpp"

namespace duckdb {

@@ -21,10 +22,10 @@ unique_ptr<PhysicalOperator> DuckCatalog::PlanCreateTableAs(ClientContext &conte
auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads();
unique_ptr<PhysicalOperator> create;
if (!parallel_streaming_insert && use_batch_index) {
-create = make_uniq<PhysicalBatchInsert>(op, op.schema, std::move(op.info), op.estimated_cardinality);
+create = make_uniq<PhysicalBatchInsert>(op, op.schema, std::move(op.info), 0);

} else {
-create = make_uniq<PhysicalInsert>(op, op.schema, std::move(op.info), op.estimated_cardinality,
+create = make_uniq<PhysicalInsert>(op, op.schema, std::move(op.info), 0,
parallel_streaming_insert && num_threads > 1);
}

76 changes: 67 additions & 9 deletions src/execution/radix_partitioned_hashtable.cpp
@@ -69,9 +69,11 @@ unique_ptr<GroupedAggregateHashTable> RadixPartitionedHashTable::CreateHT(Client
// Sink
//===--------------------------------------------------------------------===//
struct AggregatePartition {
-explicit AggregatePartition(unique_ptr<TupleDataCollection> data_p) : data(std::move(data_p)), finalized(false) {
+explicit AggregatePartition(unique_ptr<TupleDataCollection> data_p)
+: data(std::move(data_p)), progress(0), finalized(false) {
}
unique_ptr<TupleDataCollection> data;
+atomic<double> progress;
atomic<bool> finalized;
};

@@ -135,6 +137,8 @@ class RadixHTGlobalSinkState : public GlobalSinkState {
void Destroy();

public:
+ClientContext &context;
+
//! The radix HT
const RadixPartitionedHashTable &radix_ht;
//! Config for partitioning
@@ -168,10 +172,10 @@ class RadixHTGlobalSinkState : public GlobalSinkState {
idx_t count_before_combining;
};

-RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht_p)
-: radix_ht(radix_ht_p), config(context, *this), finalized(false), external(false), active_threads(0),
-any_combined(false), finalize_idx(0), scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE),
-count_before_combining(0) {
+RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const RadixPartitionedHashTable &radix_ht_p)
+: context(context_p), radix_ht(radix_ht_p), config(context, *this), finalized(false), external(false),
+active_threads(0), any_combined(false), finalize_idx(0),
+scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE), count_before_combining(0) {
}

RadixHTGlobalSinkState::~RadixHTGlobalSinkState() {
@@ -479,9 +483,32 @@ void RadixPartitionedHashTable::Finalize(ClientContext &, GlobalSinkState &gstat
//===--------------------------------------------------------------------===//
// Source
//===--------------------------------------------------------------------===//
-idx_t RadixPartitionedHashTable::NumberOfPartitions(GlobalSinkState &sink_p) const {
+idx_t RadixPartitionedHashTable::MaxThreads(GlobalSinkState &sink_p) const {
auto &sink = sink_p.Cast<RadixHTGlobalSinkState>();
-return sink.partitions.size();
+if (sink.partitions.empty()) {
+return 0;
+}
+
+// We take the largest partition as an example
+reference<TupleDataCollection> largest_partition = *sink.partitions[0]->data;
+for (idx_t i = 1; i < sink.partitions.size(); i++) {
+auto &partition = *sink.partitions[i]->data;
+if (partition.Count() > largest_partition.get().Count()) {
+largest_partition = partition;
+}
+}
+
+// Worst-case size if every value is unique
+const auto maximum_combined_partition_size =
+GroupedAggregateHashTable::GetCapacityForCount(largest_partition.get().Count()) * sizeof(aggr_ht_entry_t) +
+largest_partition.get().SizeInBytes();
+
+// How many of these can we fit in 60% of memory
+const idx_t memory_limit = 0.6 * BufferManager::GetBufferManager(sink.context).GetMaxMemory();
+const auto partitions_that_fit = MaxValue<idx_t>(memory_limit / maximum_combined_partition_size, 1);
+
+// Of course, limit it to the number of threads
+return MinValue<idx_t>(sink.partitions.size(), partitions_that_fit);
}
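A worked example of this cap, with assumed numbers throughout: a 10 GB memory limit budgets 6 GB (60%); if the largest partition holds 8M rows in 512 MB, GetCapacityForCount(8M) rounds up to 16M slots, and sizeof(aggr_ht_entry_t) is 16 bytes, then the worst-case combined size is 16M * 16 B + 512 MB = 768 MB. At most 6 GB / 768 MB = 8 such hash tables fit, so the source runs at most 8 threads even if more partitions (and cores) are available.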

void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &sink_p) {
@@ -649,7 +676,17 @@ void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlob
if (!ht) {
// Create a HT with sufficient capacity
const auto capacity = GroupedAggregateHashTable::GetCapacityForCount(partition.data->Count());
-ht = sink.radix_ht.CreateHT(gstate.context, capacity, 0);
+
+// However, we will limit the initial capacity so we don't do a huge over-allocation
+const idx_t n_threads = TaskScheduler::GetScheduler(gstate.context).NumberOfThreads();
+const idx_t memory_limit = BufferManager::GetBufferManager(gstate.context).GetMaxMemory();
+const idx_t thread_limit = 0.6 * memory_limit / n_threads;
+
+const idx_t size_per_entry = partition.data->SizeInBytes() / partition.data->Count() +
+idx_t(GroupedAggregateHashTable::LOAD_FACTOR * sizeof(aggr_ht_entry_t));
+const auto capacity_limit = NextPowerOfTwo(thread_limit / size_per_entry);
+
+ht = sink.radix_ht.CreateHT(gstate.context, MinValue<idx_t>(capacity, capacity_limit), 0);
} else {
// We may want to resize here to the size of this partition, but for now we just assume uniform partition sizes
ht->InitializePartitionedData();
@@ -658,7 +695,7 @@ void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlob
}

// Now combine the uncombined data using this thread's HT
-ht->Combine(*partition.data);
+ht->Combine(*partition.data, &partition.progress);
ht->UnpinData();

// Move the combined data back to the partition
Expand Down Expand Up @@ -812,4 +849,25 @@ SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, D
}
}
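The initial-capacity clamp in RadixHTLocalSourceState::Finalize above applies the same 60% budget, divided evenly across threads. Again with assumed numbers: 8 threads and a 10 GB limit give thread_limit = 0.6 * 10 GB / 8 = 768 MB; if rows average 72 bytes and LOAD_FACTOR * sizeof(aggr_ht_entry_t) adds 24 bytes, then size_per_entry = 96 B and the capacity is clamped to NextPowerOfTwo(768 MB / 96 B) = 8M slots, however large the partition's count-based estimate was.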

+double RadixPartitionedHashTable::GetProgress(ClientContext &, GlobalSinkState &sink_p,
+GlobalSourceState &gstate_p) const {
+auto &sink = sink_p.Cast<RadixHTGlobalSinkState>();
+auto &gstate = gstate_p.Cast<RadixHTGlobalSourceState>();
+
+// Get partition combine progress, weigh it 2x
+double total_progress = 0;
+for (auto &partition : sink.partitions) {
+total_progress += partition->progress * 2.0;
+}
+
+// Get scan progress, weigh it 1x
+total_progress += gstate.scan_done;
+
+// Divide by 3x for the weights, and the number of partitions to get a value between 0 and 1 again
+total_progress /= 3.0 * sink.partitions.size();
+
+// Multiply by 100 to get a percentage
+return 100.0 * total_progress;
+}
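To make the weighting concrete (assumed numbers, reading scan_done as the count of fully scanned partitions): four partitions whose combine progress sums to 3.0 contribute 2 * 3.0 = 6.0; one fully scanned partition adds 1.0, for 7.0 total. Dividing by 3.0 * 4 = 12 gives about 0.58, so the operator reports roughly 58%.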

} // namespace duckdb
2 changes: 1 addition & 1 deletion src/include/duckdb/execution/aggregate_hashtable.hpp
@@ -138,7 +138,7 @@ class GroupedAggregateHashTable : public BaseAggregateHashTable {

//! Executes the filter(if any) and update the aggregates
void Combine(GroupedAggregateHashTable &other);
-void Combine(TupleDataCollection &other_data);
+void Combine(TupleDataCollection &other_data, optional_ptr<atomic<double>> progress = nullptr);

//! Unpins the data blocks
void UnpinData();
2 changes: 2 additions & 0 deletions src/include/duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp
@@ -93,6 +93,8 @@ class PhysicalHashAggregate : public PhysicalOperator {
GlobalSourceState &gstate) const override;
SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const override;

+double GetProgress(ClientContext &context, GlobalSourceState &gstate) const override;
+
bool IsSource() const override {
return true;
}
4 changes: 3 additions & 1 deletion src/include/duckdb/execution/radix_partitioned_hashtable.hpp
@@ -50,8 +50,10 @@ class RadixPartitionedHashTable {
SourceResultType GetData(ExecutionContext &context, DataChunk &chunk, GlobalSinkState &sink,
OperatorSourceInput &input) const;

+double GetProgress(ClientContext &context, GlobalSinkState &sink_p, GlobalSourceState &gstate) const;
+
const TupleDataLayout &GetLayout() const;
-idx_t NumberOfPartitions(GlobalSinkState &sink) const;
+idx_t MaxThreads(GlobalSinkState &sink) const;
static void SetMultiScan(GlobalSinkState &sink);

private:
