Skip to content

Commit

Permalink
Aggregate rework: unify AggregateFunctions and ScalarFunctions Catalo…
Browse files Browse the repository at this point in the history
…gSets and postpone function binding until the binding phase
  • Loading branch information
Mytherin committed Jul 31, 2019
1 parent 7259cab commit 3989acf
Show file tree
Hide file tree
Showing 51 changed files with 274 additions and 376 deletions.
19 changes: 4 additions & 15 deletions src/catalog/catalog.cpp
Expand Up @@ -116,26 +116,15 @@ TableFunctionCatalogEntry *Catalog::GetTableFunction(Transaction &transaction, F
return schema->GetTableFunction(transaction, expression);
}

void Catalog::CreateAggregateFunction(Transaction &transaction, CreateAggregateFunctionInfo *info) {
void Catalog::CreateFunction(Transaction &transaction, CreateFunctionInfo *info) {
auto schema = GetSchema(transaction, info->schema);
schema->CreateAggregateFunction(transaction, info);
schema->CreateFunction(transaction, info);
}

void Catalog::CreateScalarFunction(Transaction &transaction, CreateScalarFunctionInfo *info) {
auto schema = GetSchema(transaction, info->schema);
schema->CreateScalarFunction(transaction, info);
}

AggregateFunctionCatalogEntry *Catalog::GetAggregateFunction(Transaction &transaction, const string &schema_name,
CatalogEntry *Catalog::GetFunction(Transaction &transaction, const string &schema_name,
const string &name, bool if_exists) {
auto schema = GetSchema(transaction, schema_name);
return schema->GetAggregateFunction(transaction, name, if_exists);
}

ScalarFunctionCatalogEntry *Catalog::GetScalarFunction(Transaction &transaction, const string &schema_name,
const string &name) {
auto schema = GetSchema(transaction, schema_name);
return schema->GetScalarFunction(transaction, name);
return schema->GetFunction(transaction, name, if_exists);
}

void Catalog::DropIndex(Transaction &transaction, DropInfo *info) {
Expand Down
63 changes: 19 additions & 44 deletions src/catalog/catalog_entry/schema_catalog_entry.cpp
Expand Up @@ -27,7 +27,7 @@ using namespace std;

SchemaCatalogEntry::SchemaCatalogEntry(Catalog *catalog, string name)
: CatalogEntry(CatalogType::SCHEMA, catalog, name), tables(*catalog), indexes(*catalog), table_functions(*catalog),
aggregate_functions(*catalog), scalar_functions(*catalog), sequences(*catalog) {
functions(*catalog), sequences(*catalog) {
}

void SchemaCatalogEntry::CreateTable(Transaction &transaction, BoundCreateTableInfo *info) {
Expand Down Expand Up @@ -178,62 +178,37 @@ void SchemaCatalogEntry::CreateTableFunction(Transaction &transaction, CreateTab
}
}

void SchemaCatalogEntry::CreateAggregateFunction(Transaction &transaction, CreateAggregateFunctionInfo *info) {
auto aggregate_function = make_unique_base<CatalogEntry, AggregateFunctionCatalogEntry>(catalog, this, info);
void SchemaCatalogEntry::CreateFunction(Transaction &transaction, CreateFunctionInfo *info) {
unique_ptr<CatalogEntry> function;
if (info->type == FunctionType::SCALAR) {
// create a scalar function
function = make_unique_base<CatalogEntry, ScalarFunctionCatalogEntry>(catalog, this, (CreateScalarFunctionInfo*) info);
} else {
// create an aggregate function
function = make_unique_base<CatalogEntry, AggregateFunctionCatalogEntry>(catalog, this, (CreateAggregateFunctionInfo*) info);
}
unordered_set<CatalogEntry *> dependencies{this};

if (!aggregate_functions.CreateEntry(transaction, info->name, move(aggregate_function), dependencies)) {
if (!info->or_replace) {
throw CatalogException("Aggregate function with name \"%s\" already exists!", info->name.c_str());
} else {
auto aggregate_function = make_unique_base<CatalogEntry, AggregateFunctionCatalogEntry>(catalog, this, info);
// function already exists: replace it
if (!aggregate_functions.DropEntry(transaction, info->name, false)) {
throw CatalogException("CREATE OR REPLACE was specified, but "
"aggregate function could not be dropped!");
}
if (!aggregate_functions.CreateEntry(transaction, info->name, move(aggregate_function), dependencies)) {
throw CatalogException("Error in recreating aggregate function in CREATE OR REPLACE");
}
}
if (info->or_replace) {
// replace is set: drop the function if it exists
functions.DropEntry(transaction, info->name, false);
}
}

void SchemaCatalogEntry::CreateScalarFunction(Transaction &transaction, CreateScalarFunctionInfo *info) {
auto scalar_function = make_unique_base<CatalogEntry, ScalarFunctionCatalogEntry>(catalog, this, info);
unordered_set<CatalogEntry *> dependencies{this};

if (!scalar_functions.CreateEntry(transaction, info->name, move(scalar_function), dependencies)) {
if (!functions.CreateEntry(transaction, info->name, move(function), dependencies)) {
if (!info->or_replace) {
throw CatalogException("Scalar function with name \"%s\" already exists!", info->name.c_str());
throw CatalogException("Function with name \"%s\" already exists!", info->name.c_str());
} else {
auto scalar_function = make_unique_base<CatalogEntry, ScalarFunctionCatalogEntry>(catalog, this, info);
// function already exists: replace it
if (!scalar_functions.DropEntry(transaction, info->name, false)) {
throw CatalogException("CREATE OR REPLACE was specified, but "
"function could not be dropped!");
}
if (!scalar_functions.CreateEntry(transaction, info->name, move(scalar_function), dependencies)) {
throw CatalogException("Error in recreating function in CREATE OR REPLACE");
}
throw CatalogException("Error in creating function in CREATE OR REPLACE");
}
}
}

AggregateFunctionCatalogEntry *SchemaCatalogEntry::GetAggregateFunction(Transaction &transaction, const string &name, bool if_exists) {
auto entry = aggregate_functions.GetEntry(transaction, name);
CatalogEntry *SchemaCatalogEntry::GetFunction(Transaction &transaction, const string &name, bool if_exists) {
auto entry = functions.GetEntry(transaction, name);
if (!entry && !if_exists) {
throw CatalogException("Aggregate Function with name %s does not exist!", name.c_str());
}
return (AggregateFunctionCatalogEntry *)entry;
}

ScalarFunctionCatalogEntry *SchemaCatalogEntry::GetScalarFunction(Transaction &transaction, const string &name) {
auto entry = scalar_functions.GetEntry(transaction, name);
if (!entry) {
throw CatalogException("Scalar Function with name %s does not exist!", name.c_str());
}
return (ScalarFunctionCatalogEntry *)entry;
return entry;
}

SequenceCatalogEntry *SchemaCatalogEntry::GetSequence(Transaction &transaction, const string &name) {
Expand Down
2 changes: 0 additions & 2 deletions src/common/symbols.cpp
Expand Up @@ -56,7 +56,6 @@ template class std::unique_ptr<QueryNode>;
template class std::unique_ptr<SelectNode>;
template class std::unique_ptr<SetOperationNode>;
template class std::unique_ptr<ParsedExpression>;
template class std::unique_ptr<AggregateExpression>;
template class std::unique_ptr<CaseExpression>;
template class std::unique_ptr<CastExpression>;
template class std::unique_ptr<ColumnRefExpression>;
Expand Down Expand Up @@ -161,7 +160,6 @@ template class std::unique_ptr<Binder>;
template VECTOR_DEFINITION::const_reference VECTOR_DEFINITION::front() const; \
template VECTOR_DEFINITION::reference VECTOR_DEFINITION::front();

template class std::vector<AggregateExpression *>;
template class std::vector<BoundTable>;
INSTANTIATE_VECTOR(std::vector<ColumnDefinition>);
template class std::vector<ExpressionType>;
Expand Down
3 changes: 1 addition & 2 deletions src/execution/physical_plan/plan_delim_join.cpp
Expand Up @@ -6,7 +6,6 @@
#include "execution/physical_plan_generator.hpp"
#include "planner/operator/logical_delim_join.hpp"
#include "planner/expression/bound_aggregate_expression.hpp"
#include "parser/expression/aggregate_expression.hpp"
#include "main/client_context.hpp"

using namespace duckdb;
Expand Down Expand Up @@ -60,7 +59,7 @@ unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalDelimJoin
vector<string> aggregate_names = {"count_star", "count"};
vector<BoundAggregateExpression*> correlated_aggregates;
for (index_t i = 0; i < aggregate_names.size(); ++i) {
auto func = context.catalog.GetAggregateFunction(context.ActiveTransaction(), DEFAULT_SCHEMA, aggregate_names[i]);
auto func = (AggregateFunctionCatalogEntry*) context.catalog.GetFunction(context.ActiveTransaction(), DEFAULT_SCHEMA, aggregate_names[i]);
auto aggr = make_unique<BoundAggregateExpression>(payload_types[i], nullptr, func, false);
correlated_aggregates.push_back(&*aggr);
info.correlated_aggregates.push_back(move(aggr));
Expand Down
2 changes: 1 addition & 1 deletion src/execution/physical_plan/plan_distinct.cpp
Expand Up @@ -42,7 +42,7 @@ unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreateDistinctOn(unique_ptr<
for (index_t i = 0; i < child_projection.select_list.size(); ++i) {
// first we create an aggregate that returns the FIRST element
auto bound = make_unique<BoundReferenceExpression>(types[i], i);
auto first_func = context.catalog.GetAggregateFunction(context.ActiveTransaction(), DEFAULT_SCHEMA, "first");
auto first_func = (AggregateFunctionCatalogEntry*) context.catalog.GetFunction(context.ActiveTransaction(), DEFAULT_SCHEMA, "first");
auto first_aggregate =
make_unique<BoundAggregateExpression>(types[i], move(bound), first_func, false);
// and push it to the list of aggregates
Expand Down
5 changes: 2 additions & 3 deletions src/function/aggregate_function/algebraic.cpp
Expand Up @@ -70,10 +70,9 @@ void avg_finalize(Vector& payloads, Vector &result) {

if (*count_ptr == 0) {
result.nullmask[i] = true;
return;
} else {
((double *)result.data)[i] = *sum_ptr / *count_ptr;
}

((double *)result.data)[i] = *sum_ptr;
});
}

Expand Down
4 changes: 2 additions & 2 deletions src/function/function.cpp
Expand Up @@ -42,7 +42,7 @@ template <class T> static void AddAggregateFunction(Transaction &transaction, Ca
info.return_type = T::GetReturnTypeFunction();
info.cast_arguments = T::GetCastArgumentsFunction();

catalog.CreateAggregateFunction(transaction, &info);
catalog.CreateFunction(transaction, &info);
}

template <class T> static void AddScalarFunction(Transaction &transaction, Catalog &catalog) {
Expand All @@ -57,7 +57,7 @@ template <class T> static void AddScalarFunction(Transaction &transaction, Catal
info.dependency = T::GetDependencyFunction();
info.has_side_effects = T::HasSideEffects();

catalog.CreateScalarFunction(transaction, &info);
catalog.CreateFunction(transaction, &info);
}

void BuiltinFunctions::Initialize(Transaction &transaction, Catalog &catalog) {
Expand Down
19 changes: 5 additions & 14 deletions src/include/catalog/catalog.hpp
Expand Up @@ -20,8 +20,7 @@ struct DropInfo;
struct BoundCreateTableInfo;
struct AlterTableInfo;
struct CreateTableFunctionInfo;
struct CreateAggregateFunctionInfo;
struct CreateScalarFunctionInfo;
struct CreateFunctionInfo;
struct CreateViewInfo;
struct CreateSequenceInfo;

Expand All @@ -30,8 +29,6 @@ class SchemaCatalogEntry;
class TableCatalogEntry;
class SequenceCatalogEntry;
class TableFunctionCatalogEntry;
class AggregateFunctionCatalogEntry;
class ScalarFunctionCatalogEntry;
class StorageManager;

//! The Catalog object represents the catalog of the database.
Expand Down Expand Up @@ -63,10 +60,8 @@ class Catalog {
void AlterTable(ClientContext &context, AlterTableInfo *info);
//! Create a table function in the catalog
void CreateTableFunction(Transaction &transaction, CreateTableFunctionInfo *info);
//! Create an aggregate function in the catalog
void CreateAggregateFunction(Transaction &transaction, CreateAggregateFunctionInfo *info);
//! Create a scalar function in the catalog
void CreateScalarFunction(Transaction &transaction, CreateScalarFunctionInfo *info);
//! Create a scalar or aggregate function in the catalog
void CreateFunction(Transaction &transaction, CreateFunctionInfo *info);

//! Creates a table in the catalog.
void CreateView(Transaction &transaction, CreateViewInfo *info);
Expand All @@ -93,13 +88,9 @@ class Catalog {
//! exception otherwise
TableFunctionCatalogEntry *GetTableFunction(Transaction &transaction, FunctionExpression *expression);

//! Returns a pointer to the aggregate function if it exists, or throws an
//! exception otherwise
AggregateFunctionCatalogEntry *GetAggregateFunction(Transaction &transaction, const string &schema, const string &name, bool if_exists = false);
//! Returns a pointer to the scalar or aggregate function if it exists, or throws an exception otherwise
CatalogEntry *GetFunction(Transaction &transaction, const string &schema, const string &name, bool if_exists = false);

//! Returns a pointer to the scalar function if it exists, or throws an
//! exception otherwise
ScalarFunctionCatalogEntry *GetScalarFunction(Transaction &transaction, const string &schema, const string &name);
//! Drops an index from the catalog.
void DropIndex(Transaction &transaction, DropInfo *info);
};
Expand Down
22 changes: 6 additions & 16 deletions src/include/catalog/catalog_entry/schema_catalog_entry.hpp
Expand Up @@ -16,22 +16,18 @@ class FunctionExpression;

class TableCatalogEntry;
class TableFunctionCatalogEntry;
class AggregateFunctionCatalogEntry;
class ScalarFunctionCatalogEntry;
class SequenceCatalogEntry;

struct AlterTableInfo;
class ClientContext;
struct CreateIndexInfo;
struct CreateTableFunctionInfo;
struct CreateAggregateFunctionInfo;
struct CreateScalarFunctionInfo;
struct CreateFunctionInfo;
struct CreateViewInfo;
struct BoundCreateTableInfo;
struct CreateSequenceInfo;
struct CreateSchemaInfo;
struct CreateTableFunctionInfo;
struct CreateScalarFunctionInfo;
struct DropInfo;

class Transaction;
Expand All @@ -47,10 +43,8 @@ class SchemaCatalogEntry : public CatalogEntry {
CatalogSet indexes;
//! The catalog set holding the table functions
CatalogSet table_functions;
//! The catalog set holding the aggregate functions
CatalogSet aggregate_functions;
//! The catalog set holding the scalar functions
CatalogSet scalar_functions;
//! The catalog set holding the scalar and aggregate functions
CatalogSet functions;
//! The catalog set holding the sequences
CatalogSet sequences;

Expand Down Expand Up @@ -87,15 +81,11 @@ class SchemaCatalogEntry : public CatalogEntry {
TableFunctionCatalogEntry *GetTableFunction(Transaction &transaction, FunctionExpression *expression);
//! Create a table function within the given schema
void CreateTableFunction(Transaction &transaction, CreateTableFunctionInfo *info);
//! Create a aggregate function within the given schema
void CreateAggregateFunction(Transaction &transaction, CreateAggregateFunctionInfo *info);
//! Create a scalar function within the given schema
void CreateScalarFunction(Transaction &transaction, CreateScalarFunctionInfo *info);
//! Create a scalar or aggregate function within the given schema
void CreateFunction(Transaction &transaction, CreateFunctionInfo *info);

//! Gets a scalar function with the given name
AggregateFunctionCatalogEntry *GetAggregateFunction(Transaction &transaction, const string &name, bool if_exists = false);
//! Gets a scalar function with the given name
ScalarFunctionCatalogEntry *GetScalarFunction(Transaction &transaction, const string &name);
CatalogEntry *GetFunction(Transaction &transaction, const string &name, bool if_exists = false);
//! Gets the sequence with the given name
SequenceCatalogEntry *GetSequence(Transaction &transaction, const string &name);

Expand Down
15 changes: 8 additions & 7 deletions src/include/common/enums/catalog_type.hpp
Expand Up @@ -21,13 +21,14 @@ enum class CatalogType : uint8_t {
SCHEMA = 2,
TABLE_FUNCTION = 3,
SCALAR_FUNCTION = 4,
VIEW = 5,
INDEX = 6,
UPDATED_ENTRY = 10,
DELETED_ENTRY = 11,
PREPARED_STATEMENT = 12,
SEQUENCE = 13,
AGGREGATE_FUNCTION = 14
AGGREGATE_FUNCTION = 5,
VIEW = 6,
INDEX = 7,
PREPARED_STATEMENT = 8,
SEQUENCE = 9,

UPDATED_ENTRY = 50,
DELETED_ENTRY = 51,
};

} // namespace duckdb
6 changes: 5 additions & 1 deletion src/include/function/aggregate_function/algebraic.hpp
Expand Up @@ -25,6 +25,10 @@ void avg_update(Vector inputs[], index_t input_count, Vector &result);
void avg_finalize(Vector& payloads, Vector &result);
SQLType avg_get_return_type(vector<SQLType> &arguments);

static Value avg_simple_initialize() {
return Value(TypeId::DOUBLE);
}

class AvgFunction {
public:
static const char*GetName() {
Expand All @@ -48,7 +52,7 @@ class AvgFunction {
}

static aggregate_simple_initialize_t GetSimpleInitializeFunction() {
return nullptr;
return avg_simple_initialize;
}

static aggregate_simple_update_t GetSimpleUpdateFunction() {
Expand Down
44 changes: 0 additions & 44 deletions src/include/parser/expression/aggregate_expression.hpp

This file was deleted.

0 comments on commit 3989acf

Please sign in to comment.