Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing Filter Clause for aggregates #1309

Merged
merged 11 commits on Jan 26, 2021
45 changes: 38 additions & 7 deletions src/execution/aggregate_hashtable.cpp
@@ -1,13 +1,14 @@
#include "duckdb/execution/aggregate_hashtable.hpp"

#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/common/algorithm.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/types/null_value.hpp"
#include "duckdb/common/vector_operations/unary_executor.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/common/vector_operations/unary_executor.hpp"
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/algorithm.hpp"
#include "duckdb/storage/buffer_manager.hpp"

#include <cmath>
Expand Down Expand Up @@ -274,6 +275,7 @@ idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashe
// for any entries for which a group was found, update the aggregate
auto &aggr = aggregates[aggr_idx];
auto input_count = (idx_t)aggr.child_count;

if (aggr.distinct) {
// construct chunk for secondary hash table probing
vector<LogicalType> probe_types(group_types);
Expand Down Expand Up @@ -308,10 +310,39 @@ idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashe
}

distinct_addresses.Verify(new_group_count);

aggr.function.update(input_count == 0 ? nullptr : &payload.data[payload_idx], input_count,
distinct_addresses, new_group_count);
if (aggr.filter) {
pdet marked this conversation as resolved.
Show resolved Hide resolved
ExpressionExecutor filter_execution(aggr.filter);
SelectionVector true_sel(STANDARD_VECTOR_SIZE);
auto count = filter_execution.SelectExpression(payload, true_sel);
DataChunk filtered_payload;
auto pay_types = payload.GetTypes();
filtered_payload.Initialize(pay_types);
filtered_payload.Slice(payload, true_sel, count);
Vector filtered_addresses;
filtered_addresses.Slice(distinct_addresses, true_sel, count);
filtered_addresses.Normalify(count);
aggr.function.update(input_count == 0 ? nullptr : &filtered_payload.data[payload_idx], input_count,
filtered_addresses, filtered_payload.size());
} else {
aggr.function.update(input_count == 0 ? nullptr : &payload.data[payload_idx], input_count,
distinct_addresses, new_group_count);
}
}
} else if (aggr.filter) {
ExpressionExecutor filter_execution(aggr.filter);
SelectionVector true_sel(STANDARD_VECTOR_SIZE);
auto count = filter_execution.SelectExpression(payload, true_sel);
DataChunk filtered_payload;
auto pay_types = payload.GetTypes();
filtered_payload.Initialize(pay_types);
filtered_payload.Slice(payload, true_sel, count);
Vector filtered_addresses;
filtered_addresses.Slice(addresses, true_sel, count);
filtered_addresses.Normalify(count);
aggr.function.update(input_count == 0 ? nullptr : &filtered_payload.data[payload_idx], input_count,
filtered_addresses, filtered_payload.size());
payload_idx += input_count;
VectorOperations::AddInPlace(filtered_addresses, aggr.payload_size, filtered_payload.size());
} else {
aggr.function.update(input_count == 0 ? nullptr : &payload.data[payload_idx], input_count, addresses,
payload.size());
Expand Down
2 changes: 1 addition & 1 deletion src/execution/base_aggregate_hashtable.cpp
Expand Up @@ -11,7 +11,7 @@ vector<AggregateObject> AggregateObject::CreateAggregateObjects(vector<BoundAggr
payload_size = BaseAggregateHashTable::Align(payload_size);
#endif
aggregates.push_back(AggregateObject(binding->function, binding->bind_info.get(), binding->children.size(),
payload_size, binding->distinct, binding->return_type.InternalType()));
payload_size, binding->distinct, binding->return_type.InternalType(),binding->filter.get()));
}
return aggregates;
}
Expand Down
2 changes: 1 addition & 1 deletion src/execution/expression_executor.cpp
Expand Up @@ -40,7 +40,7 @@ void ExpressionExecutor::Execute(DataChunk *input, DataChunk &result) {
SetChunk(input);

D_ASSERT(expressions.size() == result.ColumnCount());
D_ASSERT(expressions.size() > 0);
D_ASSERT(!expressions.empty());
for (idx_t i = 0; i < expressions.size(); i++) {
ExecuteExpression(i, result.data[i]);
}
Expand Down
63 changes: 50 additions & 13 deletions src/execution/operator/aggregate/physical_hash_aggregate.cpp
@@ -1,15 +1,14 @@
#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"

#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/aggregate_hashtable.hpp"
#include "duckdb/execution/partitionable_hashtable.hpp"

#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/main/client_context.hpp"
#include "duckdb/parallel/pipeline.hpp"
#include "duckdb/parallel/task_scheduler.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp"
#include "duckdb/main/client_context.hpp"

namespace duckdb {

Expand All @@ -24,7 +23,7 @@ PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<Logi
: PhysicalSink(type, types), groups(move(groups_p)), all_combinable(true), any_distinct(false) {
// get a list of all aggregates to be computed
// fake a single group with a constant value for aggregation without groups
if (this->groups.size() == 0) {
if (this->groups.empty()) {
group_types.push_back(LogicalType::TINYINT);
is_implicit_aggr = true;
} else {
Expand All @@ -45,8 +44,16 @@ PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector<Logi
}

aggregate_return_types.push_back(aggr.return_type);
for (idx_t i = 0; i < aggr.children.size(); ++i) {
payload_types.push_back(aggr.children[i]->return_type);
for (auto &child : aggr.children) {
payload_types.push_back(child->return_type);
}
if (aggr.filter) {
vector<LogicalType> types;
vector<vector<Expression *>> bound_refs;
BoundAggregateExpression::GetColumnRef(aggr.filter.get(), bound_refs, types);
pdet marked this conversation as resolved.
Show resolved Hide resolved
for (auto type : types) {
payload_types.push_back(type);
}
}
if (!aggr.function.combine) {
all_combinable = false;
Expand Down Expand Up @@ -86,12 +93,12 @@ class HashAggregateLocalState : public LocalSinkState {
public:
HashAggregateLocalState(PhysicalHashAggregate &_op) : op(_op), is_empty(true) {
group_chunk.InitializeEmpty(op.group_types);
if (op.payload_types.size() > 0) {
if (!op.payload_types.empty()) {
aggregate_input_chunk.InitializeEmpty(op.payload_types);
}

// if there are no groups we create a fake group so everything has the same group
if (op.groups.size() == 0) {
if (op.groups.empty()) {
group_chunk.data[0].Reference(Value::TINYINT(42));
}
}
Expand Down Expand Up @@ -130,13 +137,39 @@ void PhysicalHashAggregate::Sink(ExecutionContext &context, GlobalOperatorState
group_chunk.data[group_idx].Reference(input.data[bound_ref_expr.index]);
}
idx_t aggregate_input_idx = 0;
for (idx_t i = 0; i < aggregates.size(); i++) {
auto &aggr = (BoundAggregateExpression &)*aggregates[i];
for (auto &aggregate : aggregates) {
auto &aggr = (BoundAggregateExpression &)*aggregate;
for (auto &child_expr : aggr.children) {
D_ASSERT(child_expr->type == ExpressionType::BOUND_REF);
auto &bound_ref_expr = (BoundReferenceExpression &)*child_expr;
aggregate_input_chunk.data[aggregate_input_idx++].Reference(input.data[bound_ref_expr.index]);
}
if (aggr.filter) {
vector<LogicalType> types;
vector<vector<Expression *>> bound_refs;
BoundAggregateExpression::GetColumnRef(aggr.filter.get(), bound_refs, types);
pdet marked this conversation as resolved.
Show resolved Hide resolved
auto f_map = filter_map.find(aggr.filter.get());
if (f_map == filter_map.end()){
unordered_map<size_t,size_t> new_map;
filter_map[aggr.filter.get()] = std::make_pair (true, new_map);
f_map = filter_map.find(aggr.filter.get());
}
for (auto &bound_ref : bound_refs) {
auto &bound_ref_expr = (BoundReferenceExpression &)*bound_ref[0];
if (f_map->second.first) {
aggregate_input_chunk.data[aggregate_input_idx++].Reference(input.data[bound_ref_expr.index]);
f_map->second.second[aggregate_input_idx - 1] = bound_ref_expr.index;
for (auto &bound_ref_up : bound_ref) {
auto &bound_ref_up_expr = (BoundReferenceExpression &)*bound_ref_up;
bound_ref_up_expr.index = aggregate_input_idx - 1;
}
} else {
aggregate_input_chunk.data[aggregate_input_idx].Reference(input.data[f_map->second.second[aggregate_input_idx]]);
aggregate_input_idx++;
}
}
f_map->second.first = false;
}
}

group_chunk.SetCardinality(input.size());
Expand All @@ -151,7 +184,7 @@ void PhysicalHashAggregate::Sink(ExecutionContext &context, GlobalOperatorState
if (ForceSingleHT(state)) {
lock_guard<mutex> glock(gstate.lock);
gstate.is_empty = gstate.is_empty && group_chunk.size() == 0;
if (gstate.finalized_hts.size() == 0) {
if (gstate.finalized_hts.empty()) {
gstate.finalized_hts.push_back(
make_unique<GroupedAggregateHashTable>(BufferManager::GetBufferManager(context.client), group_types,
payload_types, bindings, HtEntryType::HT_WIDTH_64));
Expand Down Expand Up @@ -430,10 +463,14 @@ string PhysicalHashAggregate::ParamsToString() const {
result += groups[i]->GetName();
}
for (idx_t i = 0; i < aggregates.size(); i++) {
if (i > 0 || groups.size() > 0) {
auto &aggregate = (BoundAggregateExpression &)*aggregates[i];
if (i > 0 || !groups.empty()) {
result += "\n";
}
result += aggregates[i]->GetName();
if (aggregate.filter) {
result += aggregate.filter->GetName();
pdet marked this conversation as resolved.
Show resolved Hide resolved
}
}
return result;
}
Expand Down
Expand Up @@ -35,8 +35,16 @@ PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(ClientContext &contex

D_ASSERT(!aggr.distinct);
D_ASSERT(aggr.function.combine);
for (idx_t i = 0; i < aggr.children.size(); ++i) {
payload_types.push_back(aggr.children[i]->return_type);
for (auto & child : aggr.children) {
payload_types.push_back(child->return_type);
}
if (aggr.filter){
vector<LogicalType> types;
vector<vector<Expression*>> bound_refs;
BoundAggregateExpression::GetColumnRef(aggr.filter.get(),bound_refs,types);
pdet marked this conversation as resolved.
Show resolved Hide resolved
for (auto type:types){
payload_types.push_back(type);
}
}
}
aggregate_objects = AggregateObject::CreateAggregateObjects(move(bindings));
Expand Down Expand Up @@ -67,7 +75,7 @@ class PerfectHashAggregateLocalState : public LocalSinkState {
PerfectHashAggregateLocalState(PhysicalPerfectHashAggregate &op, ClientContext &context)
: ht(op.CreateHT(context)) {
group_chunk.InitializeEmpty(op.group_types);
if (op.payload_types.size() > 0) {
if (!op.payload_types.empty()) {
aggregate_input_chunk.InitializeEmpty(op.payload_types);
}
}
Expand Down Expand Up @@ -99,16 +107,44 @@ void PhysicalPerfectHashAggregate::Sink(ExecutionContext &context, GlobalOperato
group_chunk.data[group_idx].Reference(input.data[bound_ref_expr.index]);
}
idx_t aggregate_input_idx = 0;
for (idx_t i = 0; i < aggregates.size(); i++) {
auto &aggr = (BoundAggregateExpression &)*aggregates[i];
for (auto & aggregate : aggregates) {
auto &aggr = (BoundAggregateExpression &)*aggregate;
for (auto &child_expr : aggr.children) {
D_ASSERT(child_expr->type == ExpressionType::BOUND_REF);
auto &bound_ref_expr = (BoundReferenceExpression &)*child_expr;
aggregate_input_chunk.data[aggregate_input_idx++].Reference(input.data[bound_ref_expr.index]);
}
if (aggr.filter) {
vector<LogicalType> types;
vector<vector<Expression *>> bound_refs;
BoundAggregateExpression::GetColumnRef(aggr.filter.get(), bound_refs, types);
pdet marked this conversation as resolved.
Show resolved Hide resolved
auto f_map = filter_map.find(aggr.filter.get());
if (f_map == filter_map.end()){
unordered_map<size_t,size_t> new_map;
filter_map[aggr.filter.get()] = std::make_pair (true, new_map);
f_map = filter_map.find(aggr.filter.get());
}
for (auto &bound_ref : bound_refs) {
auto &bound_ref_expr = (BoundReferenceExpression &)*bound_ref[0];
if (f_map->second.first) {
aggregate_input_chunk.data[aggregate_input_idx++].Reference(input.data[bound_ref_expr.index]);
f_map->second.second[aggregate_input_idx - 1] = bound_ref_expr.index;
for (auto &bound_ref_up : bound_ref) {
auto &bound_ref_up_expr = (BoundReferenceExpression &)*bound_ref_up;
bound_ref_up_expr.index = aggregate_input_idx - 1;
}
} else {
aggregate_input_chunk.data[aggregate_input_idx].Reference(input.data[f_map->second.second[aggregate_input_idx]]);
aggregate_input_idx++;
}
}
f_map->second.first = false;
}
}


group_chunk.SetCardinality(input.size());

aggregate_input_chunk.SetCardinality(input.size());

group_chunk.Verify();
Expand Down Expand Up @@ -163,10 +199,14 @@ string PhysicalPerfectHashAggregate::ParamsToString() const {
result += groups[i]->GetName();
}
for (idx_t i = 0; i < aggregates.size(); i++) {
if (i > 0 || groups.size() > 0) {
if (i > 0 || !groups.empty()) {
result += "\n";
}
result += aggregates[i]->GetName();
auto &aggregate = (BoundAggregateExpression &)*aggregates[i];
if (aggregate.filter){
result += aggregate.filter->GetName();
pdet marked this conversation as resolved.
Show resolved Hide resolved
}
}
return result;
}
Expand Down
25 changes: 19 additions & 6 deletions src/execution/operator/aggregate/physical_simple_aggregate.cpp
Expand Up @@ -10,7 +10,7 @@ namespace duckdb {
PhysicalSimpleAggregate::PhysicalSimpleAggregate(vector<LogicalType> types, vector<unique_ptr<Expression>> expressions,
bool all_combinable)
: PhysicalSink(PhysicalOperatorType::SIMPLE_AGGREGATE, move(types)), aggregates(move(expressions)),
all_combinable(all_combinable) {
all_combinable(all_combinable){
}

//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -71,10 +71,10 @@ class SimpleAggregateLocalState : public LocalSinkState {
D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE);
auto &aggr = (BoundAggregateExpression &)*aggregate;
// initialize the payload chunk
if (aggr.children.size()) {
for (idx_t i = 0; i < aggr.children.size(); ++i) {
payload_types.push_back(aggr.children[i]->return_type);
child_executor.AddExpression(*aggr.children[i]);
if (!aggr.children.empty()) {
for (auto & child : aggr.children) {
payload_types.push_back(child->return_type);
child_executor.AddExpression(*child);
}
}
}
Expand Down Expand Up @@ -111,8 +111,17 @@ void PhysicalSimpleAggregate::Sink(ExecutionContext &context, GlobalOperatorStat
for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) {
auto &aggregate = (BoundAggregateExpression &)*aggregates[aggr_idx];
idx_t payload_cnt = 0;
// resolve the filter (if any)
if (aggregate.filter) {
ExpressionExecutor filter_execution (aggregate.filter.get());
SelectionVector true_sel(STANDARD_VECTOR_SIZE);
auto count = filter_execution.SelectExpression(input,true_sel);
input.Slice(true_sel,count);
sink.child_executor.SetChunk(input);
payload_chunk.SetCardinality(input);
}
// resolve the child expressions of the aggregate (if any)
if (aggregate.children.size() > 0) {
if (!aggregate.children.empty()) {
for (idx_t i = 0; i < aggregate.children.size(); ++i) {
sink.child_executor.ExecuteExpression(payload_expr_idx, payload_chunk.data[payload_idx + payload_cnt]);
payload_expr_idx++;
Expand Down Expand Up @@ -176,10 +185,14 @@ void PhysicalSimpleAggregate::GetChunkInternal(ExecutionContext &context, DataCh
// Renders this operator's aggregates for EXPLAIN output: one aggregate per
// line, with the aggregate's FILTER expression (if any) appended directly
// after the aggregate's own name.
// NOTE(review): stray review-UI text that had been pasted into the body
// ("pdet marked this conversation as resolved." etc.) was removed — it was
// not code and broke compilation.
string PhysicalSimpleAggregate::ParamsToString() const {
	string result;
	for (idx_t i = 0; i < aggregates.size(); i++) {
		auto &aggregate = (BoundAggregateExpression &)*aggregates[i];
		if (i > 0) {
			// separate consecutive aggregates with a newline
			result += "\n";
		}
		result += aggregates[i]->GetName();
		if (aggregate.filter) {
			// append the FILTER (WHERE ...) clause attached to this aggregate
			result += aggregate.filter->GetName();
		}
	}
	return result;
}
Expand Down
4 changes: 2 additions & 2 deletions src/execution/operator/join/physical_hash_join.cpp
Expand Up @@ -76,7 +76,7 @@ unique_ptr<GlobalOperatorState> PhysicalHashJoin::GetGlobalState(ClientContext &

// jury-rigging the GroupedAggregateHashTable
// we need a count_star and a count to get counts with and without NULLs
aggr = AggregateFunction::BindAggregateFunction(context, CountStarFun::GetFunction(), {}, false);
aggr = AggregateFunction::BindAggregateFunction(context, CountStarFun::GetFunction(), {}, nullptr, false);
correlated_aggregates.push_back(&*aggr);
payload_types.push_back(aggr->return_type);
info.correlated_aggregates.push_back(move(aggr));
Expand All @@ -85,7 +85,7 @@ unique_ptr<GlobalOperatorState> PhysicalHashJoin::GetGlobalState(ClientContext &
vector<unique_ptr<Expression>> children;
// this is a dummy but we need it to make the hash table understand whats going on
children.push_back(make_unique_base<Expression, BoundReferenceExpression>(count_fun.return_type, 0));
aggr = AggregateFunction::BindAggregateFunction(context, count_fun, move(children), false);
aggr = AggregateFunction::BindAggregateFunction(context, count_fun, move(children), nullptr, false);
correlated_aggregates.push_back(&*aggr);
payload_types.push_back(aggr->return_type);
info.correlated_aggregates.push_back(move(aggr));
Expand Down