From 455fde93b208fb58c45f4231e38f3e6481fa93b7 Mon Sep 17 00:00:00 2001
From: Laurens Kuiper
Date: Fri, 9 Feb 2024 10:03:37 +0100
Subject: [PATCH] limit tasks to threads and make sure radix ht has initial
 reservation

---
 .../operator/aggregate/physical_hash_aggregate.cpp | 13 +++++++------
 .../aggregate/physical_ungrouped_aggregate.cpp     |  9 +++++----
 src/execution/radix_partitioned_hashtable.cpp      | 11 ++++++++---
 3 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/execution/operator/aggregate/physical_hash_aggregate.cpp
index e6bd3587525..f1cc27a01df 100644
--- a/src/execution/operator/aggregate/physical_hash_aggregate.cpp
+++ b/src/execution/operator/aggregate/physical_hash_aggregate.cpp
@@ -2,6 +2,7 @@
 
 #include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
 #include "duckdb/common/atomic.hpp"
+#include "duckdb/common/optional_idx.hpp"
 #include "duckdb/common/vector_operations/vector_operations.hpp"
 #include "duckdb/execution/aggregate_hashtable.hpp"
 #include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp"
@@ -14,7 +15,6 @@
 #include "duckdb/planner/expression/bound_aggregate_expression.hpp"
 #include "duckdb/planner/expression/bound_constant_expression.hpp"
 #include "duckdb/planner/expression/bound_reference_expression.hpp"
-#include "duckdb/common/optional_idx.hpp"
 
 namespace duckdb {
 
@@ -565,9 +565,10 @@ class HashAggregateDistinctFinalizeTask : public ExecutorTask {
 };
 
 void HashAggregateDistinctFinalizeEvent::Schedule() {
-	const auto n_threads = CreateGlobalSources();
+	auto n_tasks = CreateGlobalSources();
+	n_tasks = MinValue<idx_t>(n_tasks, TaskScheduler::GetScheduler(context).NumberOfThreads());
 	vector<shared_ptr<Task>> tasks;
-	for (idx_t i = 0; i < n_threads; i++) {
+	for (idx_t i = 0; i < n_tasks; i++) {
 		tasks.push_back(make_uniq<HashAggregateDistinctFinalizeTask>(*pipeline, shared_from_this(), op, gstate));
 	}
 	SetTasks(std::move(tasks));
@@ -577,7 +578,7 @@ idx_t HashAggregateDistinctFinalizeEvent::CreateGlobalSources() {
 	auto &aggregates = op.grouped_aggregate_data.aggregates;
 	global_source_states.reserve(op.groupings.size());
 
-	idx_t n_threads = 0;
+	idx_t n_tasks = 0;
 	for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) {
 		auto &grouping = op.groupings[grouping_idx];
 		auto &distinct_state = *gstate.grouping_states[grouping_idx].distinct_state;
@@ -597,13 +598,13 @@ idx_t HashAggregateDistinctFinalizeEvent::CreateGlobalSources() {
 
 			auto table_idx = distinct_data.info.table_map.at(agg_idx);
 			auto &radix_table_p = distinct_data.radix_tables[table_idx];
-			n_threads += radix_table_p->MaxThreads(*distinct_state.radix_states[table_idx]);
+			n_tasks += radix_table_p->MaxThreads(*distinct_state.radix_states[table_idx]);
 			aggregate_sources.push_back(radix_table_p->GetGlobalSourceState(context));
 		}
 		global_source_states.push_back(std::move(aggregate_sources));
 	}
 
-	return MaxValue<idx_t>(n_threads, 1);
+	return MaxValue<idx_t>(n_tasks, 1);
 }
 
 void HashAggregateDistinctFinalizeEvent::FinishEvent() {
diff --git a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp
index 693d17a0380..2401e99c854 100644
--- a/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp
+++ b/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp
@@ -431,7 +431,7 @@ void UngroupedDistinctAggregateFinalizeEvent::Schedule() {
 	auto &aggregates = op.aggregates;
 	auto &distinct_data = *op.distinct_data;
 
-	idx_t n_threads = 0;
+	idx_t n_tasks = 0;
 	idx_t payload_idx = 0;
 	idx_t next_payload_idx = 0;
 	for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) {
@@ -451,13 +451,14 @@ void UngroupedDistinctAggregateFinalizeEvent::Schedule() {
 		// Create global state for scanning
 		auto table_idx = distinct_data.info.table_map.at(agg_idx);
 		auto &radix_table_p = *distinct_data.radix_tables[table_idx];
-		n_threads += radix_table_p.MaxThreads(*gstate.distinct_state->radix_states[table_idx]);
+		n_tasks += radix_table_p.MaxThreads(*gstate.distinct_state->radix_states[table_idx]);
 		global_source_states.push_back(radix_table_p.GetGlobalSourceState(context));
 	}
-	n_threads = MaxValue<idx_t>(n_threads, 1);
+	n_tasks = MaxValue<idx_t>(n_tasks, 1);
+	n_tasks = MinValue<idx_t>(n_tasks, TaskScheduler::GetScheduler(context).NumberOfThreads());
 
 	vector<shared_ptr<Task>> tasks;
-	for (idx_t i = 0; i < n_threads; i++) {
+	for (idx_t i = 0; i < n_tasks; i++) {
 		tasks.push_back(
 		    make_uniq<UngroupedDistinctAggregateFinalizeTask>(pipeline->executor, shared_from_this(), op, gstate));
 		tasks_scheduled++;
diff --git a/src/execution/radix_partitioned_hashtable.cpp b/src/execution/radix_partitioned_hashtable.cpp
index 31137434760..952fddb0517 100644
--- a/src/execution/radix_partitioned_hashtable.cpp
+++ b/src/execution/radix_partitioned_hashtable.cpp
@@ -365,7 +365,9 @@ bool MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, Ra
 		thread_limit = temporary_memory_state.GetReservation() / gstate.active_threads;
 		if (total_size > thread_limit) {
 			// Out-of-core would be triggered below, try to increase the reservation
-			temporary_memory_state.SetRemainingSize(context, 2 * temporary_memory_state.GetRemainingSize());
+			auto remaining_size =
+			    MaxValue(gstate.active_threads * total_size, temporary_memory_state.GetRemainingSize());
+			temporary_memory_state.SetRemainingSize(context, 2 * remaining_size);
 			thread_limit = temporary_memory_state.GetReservation() / gstate.active_threads;
 		}
 	}
@@ -541,9 +543,12 @@ idx_t RadixPartitionedHashTable::MaxThreads(GlobalSinkState &sink_p) const {
 
 	// This many partitions will fit given our reservation (at least 1))
 	auto partitions_fit = MaxValue<idx_t>(sink.temporary_memory_state->GetReservation() / sink.max_partition_size, 1);
+	// Maximum is either the number of partitions, or the number of threads
+	auto max_possible =
+	    MinValue<idx_t>(sink.partitions.size(), TaskScheduler::GetScheduler(sink.context).NumberOfThreads());
 
-	// Of course, limit it to the number of actual partitions
-	return MinValue(sink.partitions.size(), partitions_fit);
+	// Minimum of the two
+	return MinValue(partitions_fit, max_possible);
 }
 
 void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &sink_p) {
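
Note on the scheduling change in both aggregate operators: the sum of
MaxThreads() over the distinct tables is an upper bound on useful
parallelism, but scheduling that many tasks only adds scheduler overhead once
it exceeds the worker count. A minimal standalone sketch of the clamp, using
plain STL instead of DuckDB's MinValue/MaxValue helpers (TasksToSchedule is a
hypothetical name, not DuckDB API):

#include <algorithm>
#include <cstdint>
#include <iostream>

using idx_t = uint64_t;

// At least one task, and at most one task per scheduler thread.
idx_t TasksToSchedule(idx_t source_parallelism, idx_t n_threads) {
	idx_t n_tasks = std::max<idx_t>(source_parallelism, 1);
	return std::min<idx_t>(n_tasks, n_threads);
}

int main() {
	std::cout << TasksToSchedule(96, 8) << '\n'; // 8 tasks, not 96
	std::cout << TasksToSchedule(0, 8) << '\n';  // still schedules 1 task
}

Scheduling fewer tasks than sources is safe here because each finalize task
pulls its next chunk of work from the shared global source states until they
are drained.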
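The MaybeRepartition change addresses reservation growth when out-of-core
processing is about to be triggered: doubling only GetRemainingSize() can
still fall short of what the active threads need, so the new code first
raises the baseline to active_threads * total_size before doubling. A rough
standalone sketch (NextRemainingSize is a hypothetical name; total_size is
assumed to be the per-thread data size, as in the surrounding code):

#include <algorithm>
#include <cstdint>

using idx_t = uint64_t;

// Request at least what all active threads need, and at least a doubling,
// so repeated bumps grow the reservation geometrically.
idx_t NextRemainingSize(idx_t active_threads, idx_t total_size, idx_t current_remaining) {
	idx_t baseline = std::max<idx_t>(active_threads * total_size, current_remaining);
	return 2 * baseline;
}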
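Similarly, the new MaxThreads bound in radix_partitioned_hashtable.cpp
combines three limits instead of two: the memory reservation, the partition
count, and (new in this patch) the scheduler's thread count. A standalone
sketch of the same arithmetic (MaxScanThreads and its parameters are
hypothetical stand-ins for the sink state fields used above; assumes a
nonzero max_partition_size):

#include <algorithm>
#include <cstdint>

using idx_t = uint64_t;

// Parallelism is limited by: (1) how many partitions fit in the reservation
// at once, (2) how many partitions exist, (3) how many threads there are.
idx_t MaxScanThreads(idx_t reservation, idx_t max_partition_size, idx_t n_partitions, idx_t n_threads) {
	idx_t partitions_fit = std::max<idx_t>(reservation / max_partition_size, 1);
	idx_t max_possible = std::min<idx_t>(n_partitions, n_threads);
	return std::min<idx_t>(partitions_fit, max_possible);
}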