More TemporaryMemoryManager tweaks #10549

Merged: 1 commit, merged on Feb 9, 2024
13 changes: 7 additions & 6 deletions src/execution/operator/aggregate/physical_hash_aggregate.cpp
@@ -2,6 +2,7 @@

#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/common/atomic.hpp"
#include "duckdb/common/optional_idx.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/aggregate_hashtable.hpp"
#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp"
@@ -14,7 +15,6 @@
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_reference_expression.hpp"
#include "duckdb/common/optional_idx.hpp"

namespace duckdb {

@@ -565,9 +565,10 @@ class HashAggregateDistinctFinalizeTask : public ExecutorTask {
};

void HashAggregateDistinctFinalizeEvent::Schedule() {
const auto n_threads = CreateGlobalSources();
auto n_tasks = CreateGlobalSources();
n_tasks = MinValue<idx_t>(n_tasks, TaskScheduler::GetScheduler(context).NumberOfThreads());
vector<shared_ptr<Task>> tasks;
for (idx_t i = 0; i < n_threads; i++) {
for (idx_t i = 0; i < n_tasks; i++) {
tasks.push_back(make_uniq<HashAggregateDistinctFinalizeTask>(*pipeline, shared_from_this(), op, gstate));
}
SetTasks(std::move(tasks));
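
Note on this hunk: the task count is first derived from the global sources and then capped at the scheduler's thread count. A minimal standalone sketch of the same clamping in plain C++ with illustrative names (ClampTaskCount is not a DuckDB function):

#include <algorithm>
#include <cstddef>
#include <iostream>

// Illustrative only (not DuckDB's API): derive a task count, then clamp it to
// at least one task and at most one task per worker thread, since any extra
// tasks would only sit idle in the scheduler.
static std::size_t ClampTaskCount(std::size_t n_tasks, std::size_t n_threads) {
    n_tasks = std::max<std::size_t>(n_tasks, 1); // always schedule at least one task
    return std::min(n_tasks, n_threads);         // never schedule more tasks than threads
}

int main() {
    std::cout << ClampTaskCount(48, 8) << "\n"; // 8: capped at the thread count
    std::cout << ClampTaskCount(0, 8) << "\n";  // 1: at least one task
    return 0;
}
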
@@ -577,7 +578,7 @@ idx_t HashAggregateDistinctFinalizeEvent::CreateGlobalSources() {
auto &aggregates = op.grouped_aggregate_data.aggregates;
global_source_states.reserve(op.groupings.size());

idx_t n_threads = 0;
idx_t n_tasks = 0;
for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) {
auto &grouping = op.groupings[grouping_idx];
auto &distinct_state = *gstate.grouping_states[grouping_idx].distinct_state;
@@ -597,13 +598,13 @@

auto table_idx = distinct_data.info.table_map.at(agg_idx);
auto &radix_table_p = distinct_data.radix_tables[table_idx];
n_threads += radix_table_p->MaxThreads(*distinct_state.radix_states[table_idx]);
n_tasks += radix_table_p->MaxThreads(*distinct_state.radix_states[table_idx]);
aggregate_sources.push_back(radix_table_p->GetGlobalSourceState(context));
}
global_source_states.push_back(std::move(aggregate_sources));
}

return MaxValue<idx_t>(n_threads, 1);
return MaxValue<idx_t>(n_tasks, 1);
}

void HashAggregateDistinctFinalizeEvent::FinishEvent() {
src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp
@@ -431,7 +431,7 @@ void UngroupedDistinctAggregateFinalizeEvent::Schedule() {
auto &aggregates = op.aggregates;
auto &distinct_data = *op.distinct_data;

idx_t n_threads = 0;
idx_t n_tasks = 0;
idx_t payload_idx = 0;
idx_t next_payload_idx = 0;
for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) {
@@ -451,13 +451,14 @@ void UngroupedDistinctAggregateFinalizeEvent::Schedule() {
// Create global state for scanning
auto table_idx = distinct_data.info.table_map.at(agg_idx);
auto &radix_table_p = *distinct_data.radix_tables[table_idx];
n_threads += radix_table_p.MaxThreads(*gstate.distinct_state->radix_states[table_idx]);
n_tasks += radix_table_p.MaxThreads(*gstate.distinct_state->radix_states[table_idx]);
global_source_states.push_back(radix_table_p.GetGlobalSourceState(context));
}
n_threads = MaxValue<idx_t>(n_threads, 1);
n_tasks = MaxValue<idx_t>(n_tasks, 1);
n_tasks = MinValue<idx_t>(n_tasks, TaskScheduler::GetScheduler(context).NumberOfThreads());

vector<shared_ptr<Task>> tasks;
for (idx_t i = 0; i < n_threads; i++) {
for (idx_t i = 0; i < n_tasks; i++) {
tasks.push_back(
make_uniq<UngroupedDistinctAggregateFinalizeTask>(pipeline->executor, shared_from_this(), op, gstate));
tasks_scheduled++;
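
The same cap is applied to the ungrouped distinct finalize tasks. Capping is safe because tasks do not each own a fixed source; they pull from the shared global source states until nothing is left. A rough standalone illustration of that pull model, using made-up names and plain std::thread rather than DuckDB's task scheduler:

#include <atomic>
#include <cstddef>
#include <iostream>
#include <thread>
#include <vector>

int main() {
    const std::size_t n_sources = 48; // e.g. one global source per radix table
    const std::size_t n_tasks = 8;    // capped at the number of threads

    std::atomic<std::size_t> next_source{0};
    std::atomic<std::size_t> processed{0};

    std::vector<std::thread> tasks;
    for (std::size_t t = 0; t < n_tasks; t++) {
        tasks.emplace_back([&] {
            // Each task keeps claiming sources until none remain, so fewer
            // tasks than sources still drains all of the work.
            while (true) {
                const auto idx = next_source.fetch_add(1);
                if (idx >= n_sources) {
                    break;
                }
                processed.fetch_add(1);
            }
        });
    }
    for (auto &task : tasks) {
        task.join();
    }
    std::cout << processed.load() << " sources processed\n"; // prints "48 sources processed"
    return 0;
}
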
11 changes: 8 additions & 3 deletions src/execution/radix_partitioned_hashtable.cpp
@@ -365,7 +365,9 @@ bool MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, Ra
thread_limit = temporary_memory_state.GetReservation() / gstate.active_threads;
if (total_size > thread_limit) {
// Out-of-core would be triggered below, try to increase the reservation
temporary_memory_state.SetRemainingSize(context, 2 * temporary_memory_state.GetRemainingSize());
auto remaining_size =
MaxValue<idx_t>(gstate.active_threads * total_size, temporary_memory_state.GetRemainingSize());
temporary_memory_state.SetRemainingSize(context, 2 * remaining_size);
thread_limit = temporary_memory_state.GetReservation() / gstate.active_threads;
}
}
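
For reference, the hunk above grows the reservation from the larger of the current remaining size and active_threads * total_size before doubling, rather than doubling the possibly much smaller remaining size alone. A minimal sketch of that arithmetic, assuming total_size is the thread-local size estimate as in the hunk (illustrative names, not DuckDB's API):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Illustrative only: the new remaining-size request is based on the larger of
// the current remaining size and active_threads * total_size, then doubled to
// leave headroom before out-of-core processing would be triggered.
static uint64_t NewRemainingSize(uint64_t active_threads, uint64_t total_size, uint64_t current_remaining) {
    const uint64_t remaining = std::max(active_threads * total_size, current_remaining);
    return 2 * remaining;
}

int main() {
    // 8 active threads, each holding roughly 64 MiB, 128 MiB currently remaining.
    std::cout << NewRemainingSize(8, 64ULL << 20, 128ULL << 20) << "\n"; // 2 * 512 MiB = 1 GiB
    return 0;
}
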
@@ -541,9 +543,12 @@ idx_t RadixPartitionedHashTable::MaxThreads(GlobalSinkState &sink_p) const {

// This many partitions will fit given our reservation (at least 1)
auto partitions_fit = MaxValue<idx_t>(sink.temporary_memory_state->GetReservation() / sink.max_partition_size, 1);
// Maximum is either the number of partitions, or the number of threads
auto max_possible =
MinValue<idx_t>(sink.partitions.size(), TaskScheduler::GetScheduler(sink.context).NumberOfThreads());

// Of course, limit it to the number of actual partitions
return MinValue<idx_t>(sink.partitions.size(), partitions_fit);
// Minimum of the two
return MinValue<idx_t>(partitions_fit, max_possible);
}

void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &sink_p) {
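
Putting the revised MaxThreads() together: the result is the number of partitions that fit in the reservation, bounded by both the partition count and the scheduler's thread count. A standalone sketch of that calculation with plain integers and illustrative names:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Illustrative only: mirrors the revised MaxThreads() computation.
static uint64_t MaxScanThreads(uint64_t reservation, uint64_t max_partition_size, uint64_t n_partitions,
                               uint64_t n_scheduler_threads) {
    // How many partitions fit in the current reservation (at least 1).
    const uint64_t partitions_fit = std::max<uint64_t>(reservation / max_partition_size, 1);
    // Upper bound: no point exceeding the partition count or the scheduler's thread count.
    const uint64_t max_possible = std::min(n_partitions, n_scheduler_threads);
    // Minimum of the two.
    return std::min(partitions_fit, max_possible);
}

int main() {
    // 1 GiB reservation, 128 MiB max partition size, 32 partitions, 8 threads -> 8.
    std::cout << MaxScanThreads(1ULL << 30, 128ULL << 20, 32, 8) << "\n";
    return 0;
}
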