
Merge pull request #10549 from lnkuiper/radix_ht_reservation
More `TemporaryMemoryManager` tweaks
Mytherin committed Feb 9, 2024
2 parents 7811335 + 455fde9 commit 179a690
Showing 3 changed files with 20 additions and 13 deletions.
13 changes: 7 additions & 6 deletions src/execution/operator/aggregate/physical_hash_aggregate.cpp
@@ -2,6 +2,7 @@

#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/common/atomic.hpp"
#include "duckdb/common/optional_idx.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/aggregate_hashtable.hpp"
#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp"
@@ -14,7 +15,6 @@
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/common/optional_idx.hpp"

namespace duckdb {

@@ -565,9 +565,10 @@ class HashAggregateDistinctFinalizeTask : public ExecutorTask {
};

void HashAggregateDistinctFinalizeEvent::Schedule() {
-const auto n_threads = CreateGlobalSources();
+auto n_tasks = CreateGlobalSources();
+n_tasks = MinValue<idx_t>(n_tasks, TaskScheduler::GetScheduler(context).NumberOfThreads());
vector<shared_ptr<Task>> tasks;
-for (idx_t i = 0; i < n_threads; i++) {
+for (idx_t i = 0; i < n_tasks; i++) {
tasks.push_back(make_uniq<HashAggregateDistinctFinalizeTask>(*pipeline, shared_from_this(), op, gstate));
}
SetTasks(std::move(tasks));
@@ -577,7 +578,7 @@ idx_t HashAggregateDistinctFinalizeEvent::CreateGlobalSources() {
auto &aggregates = op.grouped_aggregate_data.aggregates;
global_source_states.reserve(op.groupings.size());

-idx_t n_threads = 0;
+idx_t n_tasks = 0;
for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) {
auto &grouping = op.groupings[grouping_idx];
auto &distinct_state = *gstate.grouping_states[grouping_idx].distinct_state;
@@ -597,13 +598,13 @@

auto table_idx = distinct_data.info.table_map.at(agg_idx);
auto &radix_table_p = distinct_data.radix_tables[table_idx];
-n_threads += radix_table_p->MaxThreads(*distinct_state.radix_states[table_idx]);
+n_tasks += radix_table_p->MaxThreads(*distinct_state.radix_states[table_idx]);
aggregate_sources.push_back(radix_table_p->GetGlobalSourceState(context));
}
global_source_states.push_back(std::move(aggregate_sources));
}

-return MaxValue<idx_t>(n_threads, 1);
+return MaxValue<idx_t>(n_tasks, 1);
}

void HashAggregateDistinctFinalizeEvent::FinishEvent() {
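The hunks above cap the number of distinct-finalize tasks at the scheduler's worker count instead of spawning one task per reported MaxThreads. Below is a minimal, self-contained sketch of that clamp; it uses std::min/std::max and a hypothetical ComputeTaskCount helper in place of DuckDB's MinValue/MaxValue and event machinery, so it illustrates the arithmetic only, not the actual API.

```cpp
// Sketch of the task-count clamp above (illustrative names, not DuckDB's API).
// per_table_max_threads stands in for the radix_table_p->MaxThreads(...) results,
// scheduler_threads for TaskScheduler::GetScheduler(context).NumberOfThreads().
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using idx_t = uint64_t;

idx_t ComputeTaskCount(const std::vector<idx_t> &per_table_max_threads, idx_t scheduler_threads) {
	idx_t n_tasks = 0;
	for (auto max_threads : per_table_max_threads) {
		n_tasks += max_threads; // sum of MaxThreads over all distinct aggregate tables
	}
	n_tasks = std::max<idx_t>(n_tasks, 1);              // always schedule at least one task
	return std::min<idx_t>(n_tasks, scheduler_threads); // never more tasks than worker threads
}

int main() {
	// Three tables report 4, 8 and 16 possible scan threads, but only 8 workers exist:
	std::cout << ComputeTaskCount({4, 8, 16}, 8) << '\n'; // prints 8 instead of 28
}
```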
9 changes: 5 additions & 4 deletions src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp
@@ -431,7 +431,7 @@ void UngroupedDistinctAggregateFinalizeEvent::Schedule() {
auto &aggregates = op.aggregates;
auto &distinct_data = *op.distinct_data;

-idx_t n_threads = 0;
+idx_t n_tasks = 0;
idx_t payload_idx = 0;
idx_t next_payload_idx = 0;
for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) {
@@ -451,13 +451,14 @@ void UngroupedDistinctAggregateFinalizeEvent::Schedule() {
// Create global state for scanning
auto table_idx = distinct_data.info.table_map.at(agg_idx);
auto &radix_table_p = *distinct_data.radix_tables[table_idx];
-n_threads += radix_table_p.MaxThreads(*gstate.distinct_state->radix_states[table_idx]);
+n_tasks += radix_table_p.MaxThreads(*gstate.distinct_state->radix_states[table_idx]);
global_source_states.push_back(radix_table_p.GetGlobalSourceState(context));
}
-n_threads = MaxValue<idx_t>(n_threads, 1);
+n_tasks = MaxValue<idx_t>(n_tasks, 1);
+n_tasks = MinValue<idx_t>(n_tasks, TaskScheduler::GetScheduler(context).NumberOfThreads());

vector<shared_ptr<Task>> tasks;
-for (idx_t i = 0; i < n_threads; i++) {
+for (idx_t i = 0; i < n_tasks; i++) {
tasks.push_back(
make_uniq<UngroupedDistinctAggregateFinalizeTask>(pipeline->executor, shared_from_this(), op, gstate));
tasks_scheduled++;
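UngroupedDistinctAggregateFinalizeEvent::Schedule() above applies the same clamp: n_tasks is the sum of the per-table MaxThreads values, raised to at least 1 and then capped at TaskScheduler::GetScheduler(context).NumberOfThreads(), so the sketch after the grouped hash aggregate hunks applies here unchanged.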
11 changes: 8 additions & 3 deletions src/execution/radix_partitioned_hashtable.cpp
@@ -365,7 +365,9 @@ bool MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, Ra
thread_limit = temporary_memory_state.GetReservation() / gstate.active_threads;
if (total_size > thread_limit) {
// Out-of-core would be triggered below, try to increase the reservation
-temporary_memory_state.SetRemainingSize(context, 2 * temporary_memory_state.GetRemainingSize());
+auto remaining_size =
+    MaxValue<idx_t>(gstate.active_threads * total_size, temporary_memory_state.GetRemainingSize());
+temporary_memory_state.SetRemainingSize(context, 2 * remaining_size);
thread_limit = temporary_memory_state.GetReservation() / gstate.active_threads;
}
}
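The MaybeRepartition() change above grows the reservation request from 2 * GetRemainingSize() to 2 * max(active_threads * total_size, GetRemainingSize()), so the new request is driven by what the active threads currently need rather than only by the previous reservation. A rough sketch under that reading follows; FakeTemporaryMemoryState is an illustrative stand-in, not DuckDB's TemporaryMemoryState API, and total_size mirrors the per-thread size that the hunk compares against thread_limit.

```cpp
// Sketch of the reservation bump above, with a stand-in struct for the memory state.
#include <algorithm>
#include <cstdint>
#include <iostream>

using idx_t = uint64_t;

struct FakeTemporaryMemoryState {
	idx_t remaining_size;
	void SetRemainingSize(idx_t size) { remaining_size = size; } // real API also takes a ClientContext
	idx_t GetRemainingSize() const { return remaining_size; }
};

// Before: request 2 * GetRemainingSize().
// After:  request 2 * max(active_threads * total_size, GetRemainingSize()).
void RequestLargerReservation(FakeTemporaryMemoryState &state, idx_t active_threads, idx_t total_size) {
	auto remaining_size = std::max<idx_t>(active_threads * total_size, state.GetRemainingSize());
	state.SetRemainingSize(2 * remaining_size);
}

int main() {
	FakeTemporaryMemoryState state{1 << 20};      // 1 MiB currently reserved
	RequestLargerReservation(state, 8, 4 << 20);  // 8 threads, 4 MiB per-thread size
	std::cout << state.GetRemainingSize() << '\n'; // 2 * max(32 MiB, 1 MiB) = 64 MiB
}
```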
@@ -541,9 +543,12 @@ idx_t RadixPartitionedHashTable::MaxThreads(GlobalSinkState &sink_p) const {

// This many partitions will fit given our reservation (at least 1))
auto partitions_fit = MaxValue<idx_t>(sink.temporary_memory_state->GetReservation() / sink.max_partition_size, 1);
+// Maximum is either the number of partitions, or the number of threads
+auto max_possible =
+    MinValue<idx_t>(sink.partitions.size(), TaskScheduler::GetScheduler(sink.context).NumberOfThreads());

-// Of course, limit it to the number of actual partitions
-return MinValue<idx_t>(sink.partitions.size(), partitions_fit);
+// Mininum of the two
+return MinValue<idx_t>(partitions_fit, max_possible);
}

void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &sink_p) {
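The new MaxThreads() bound above is the minimum of two quantities: how many partitions fit in the current reservation (at least 1), and the smaller of the partition count and the scheduler's thread count. A small self-contained sketch of that computation is below, with plain integer parameters standing in for the sink's fields (reservation, max_partition_size, partitions.size() and the worker count).

```cpp
// Sketch of the MaxThreads() bound above; parameter names are illustrative stand-ins.
#include <algorithm>
#include <cstdint>
#include <iostream>

using idx_t = uint64_t;

idx_t MaxScanThreads(idx_t reservation, idx_t max_partition_size, idx_t n_partitions, idx_t n_scheduler_threads) {
	// This many partitions fit in the reservation (at least 1)
	auto partitions_fit = std::max<idx_t>(reservation / max_partition_size, 1);
	// Upper bound: number of partitions or number of threads, whichever is smaller
	auto max_possible = std::min<idx_t>(n_partitions, n_scheduler_threads);
	// Minimum of the two
	return std::min<idx_t>(partitions_fit, max_possible);
}

int main() {
	// 256 MiB reserved, 64 MiB max partition size, 32 partitions, 8 worker threads:
	std::cout << MaxScanThreads(256, 64, 32, 8) << '\n'; // min(4, min(32, 8)) = 4
}
```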
