Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch fix for projections analysis with analyzer. #48357

Merged
merged 4 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 25 additions & 9 deletions src/Interpreters/ActionsDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,11 +762,10 @@ NameSet ActionsDAG::foldActionsByProjection(
}


ActionsDAGPtr ActionsDAG::foldActionsByProjection(const std::unordered_map<const Node *, std::string> & new_inputs, const NodeRawConstPtrs & required_outputs)
ActionsDAGPtr ActionsDAG::foldActionsByProjection(const std::unordered_map<const Node *, const Node *> & new_inputs, const NodeRawConstPtrs & required_outputs)
{
auto dag = std::make_unique<ActionsDAG>();
std::unordered_map<const Node *, size_t> new_input_to_pos;

std::unordered_map<const Node *, const Node *> inputs_mapping;
std::unordered_map<const Node *, const Node *> mapping;
struct Frame
{
Expand Down Expand Up @@ -796,11 +795,21 @@ ActionsDAGPtr ActionsDAG::foldActionsByProjection(const std::unordered_map<const

if (!node)
{
bool should_rename = !rename.empty() && new_input->result_name != rename;
const auto & input_name = should_rename ? rename : new_input->result_name;
node = &dag->addInput(input_name, new_input->result_type);
if (should_rename)
node = &dag->addAlias(*node, new_input->result_name);
/// It is possible to have several aliases on the same column.
/// We may want to replace all such aliases;
/// in that case they should share a single input node as a child.
auto & mapped_input = inputs_mapping[rename];

if (!mapped_input)
{
bool should_rename = new_input->result_name != rename->result_name;
const auto & input_name = should_rename ? rename->result_name : new_input->result_name;
mapped_input = &dag->addInput(input_name, new_input->result_type);
if (should_rename)
mapped_input = &dag->addAlias(*mapped_input, new_input->result_name);
}

node = mapped_input;
}

stack.pop_back();
Expand Down Expand Up @@ -836,7 +845,14 @@ ActionsDAGPtr ActionsDAG::foldActionsByProjection(const std::unordered_map<const
}

for (const auto * output : required_outputs)
dag->outputs.push_back(mapping[output]);
{
/// Keep the names for outputs.
/// Add an alias if the mapped node has a different result name.
const auto * mapped_output = mapping[output];
if (output->result_name != mapped_output->result_name)
mapped_output = &dag->addAlias(*mapped_output, output->result_name);
dag->outputs.push_back(mapped_output);
}

return dag;
}
Expand Down
10 changes: 6 additions & 4 deletions src/Interpreters/ActionsDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,11 @@ class ActionsDAG
const String & predicate_column_name = {},
bool add_missing_keys = true);

/// Get an ActionsDAG where:
/// * Subtrees from new_inputs are converted to inputs with specified names.
/// * Outputs are taken from required_outputs.
/// Get an ActionsDAG in the following way:
/// * Traverse the tree starting from required_outputs.
/// * If a node is found among the new_inputs keys, replace it with an INPUT node.
/// * The INPUT name is taken from the node the new_inputs entry maps to.
/// * Several keys may map to the same node; in that case a single INPUT is created.
/// Here we want to substitute some expressions with columns from the projection.
/// This function expects that all required_outputs can be calculated from nodes in new_inputs.
/// If not, an exception is thrown.
Expand All @@ -240,7 +242,7 @@ class ActionsDAG
/// \ /
/// c * d - e
static ActionsDAGPtr foldActionsByProjection(
const std::unordered_map<const Node *, std::string> & new_inputs,
const std::unordered_map<const Node *, const Node *> & new_inputs,
const NodeRawConstPtrs & required_outputs);

/// Reorder the output nodes using given position mapping.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,25 @@ static bool hasNullableOrMissingColumn(const DAGIndex & index, const Names & nam
return false;
}

/// One query aggregate function matched against an aggregate from the projection.
struct AggregateFunctionMatch
{
/// The matched aggregate description from the projection's info (not owned; points into info.aggregates).
const AggregateDescription * description = nullptr;
/// Argument types resolved from the query DAG for this aggregate.
/// Left empty when arguments can be ignored (e.g. for count()).
DataTypes argument_types;
};

/// Matches for all query aggregates, parallel to the query's AggregateDescriptions.
using AggregateFunctionMatches = std::vector<AggregateFunctionMatch>;

/// Here we try to match aggregate functions from the query to
/// aggregate functions from projection.
bool areAggregatesMatch(
std::optional<AggregateFunctionMatches> matchAggregateFunctions(
const AggregateProjectionInfo & info,
const AggregateDescriptions & aggregates,
const MatchedTrees::Matches & matches,
const DAGIndex & query_index,
const DAGIndex & proj_index)
{
AggregateFunctionMatches res;

/// Index (projection agg function name) -> pos
std::unordered_map<std::string, std::vector<size_t>> projection_aggregate_functions;
for (size_t i = 0; i < info.aggregates.size(); ++i)
Expand All @@ -126,14 +135,20 @@ bool areAggregatesMatch(
// "Cannot match agg func {} by name {}",
// aggregate.column_name, aggregate.function->getName());

return false;
return {};
}

size_t num_args = aggregate.argument_names.size();

DataTypes argument_types;
argument_types.reserve(num_args);

auto & candidates = it->second;
bool found_match = false;

for (size_t idx : candidates)
{
argument_types.clear();
const auto & candidate = info.aggregates[idx];

/// Note: this check is a bit strict.
Expand All @@ -144,9 +159,9 @@ bool areAggregatesMatch(
/// and we can't replace one to another from projection.
if (!candidate.function->getStateType()->equals(*aggregate.function->getStateType()))
{
LOG_TRACE(&Poco::Logger::get("optimizeUseProjections"), "Cannot match agg func {} vs {} by state {} vs {}",
aggregate.column_name, candidate.column_name,
candidate.function->getStateType()->getName(), aggregate.function->getStateType()->getName());
// LOG_TRACE(&Poco::Logger::get("optimizeUseProjections"), "Cannot match agg func {} vs {} by state {} vs {}",
//     aggregate.column_name, candidate.column_name,
//     candidate.function->getStateType()->getName(), aggregate.function->getStateType()->getName());

[Review comment by KochetovNicolai (Member): "Maybe remove it completely?"]
continue;
}

Expand All @@ -162,14 +177,14 @@ bool areAggregatesMatch(
{
/// we can ignore arguments for count()
found_match = true;
res.push_back({&candidate, DataTypes()});
break;
}
}

/// Now, function names and types matched.
/// Next, match arguments from DAGs.

size_t num_args = aggregate.argument_names.size();
if (num_args != candidate.argument_names.size())
continue;

Expand Down Expand Up @@ -211,21 +226,52 @@ bool areAggregatesMatch(
break;
}

argument_types.push_back(query_node->result_type);
++next_arg;
}

if (next_arg < aggregate.argument_names.size())
continue;

found_match = true;
res.push_back({&candidate, std::move(argument_types)});
break;
}

if (!found_match)
return false;
return {};
}

return true;
return res;
}

/// Append the matched aggregate functions to the folded projection DAG.
/// For every query aggregate, an INPUT node reading the projection's aggregate
/// state column is added (deduplicated: several query aggregates matched to the
/// same projection aggregate share one INPUT), aliased to the query's column
/// name when the names differ, and pushed to the DAG outputs.
static void appendAggregateFunctions(
    ActionsDAG & proj_dag,
    const AggregateDescriptions & aggregates,
    const AggregateFunctionMatches & matched_aggregates)
{
    /// One INPUT node per distinct matched projection aggregate.
    std::unordered_map<const AggregateDescription *, const ActionsDAG::Node *> input_nodes;

    auto & outputs = proj_dag.getOutputs();
    const size_t total = aggregates.size();
    for (size_t pos = 0; pos < total; ++pos)
    {
        const auto & query_aggregate = aggregates[pos];
        const auto & matched = matched_aggregates[pos];

        /// The INPUT carries the aggregate-function state type with the
        /// argument types resolved from the query DAG.
        auto state_type = std::make_shared<DataTypeAggregateFunction>(
            query_aggregate.function, matched.argument_types, query_aggregate.parameters);

        auto & cached_input = input_nodes[matched.description];
        if (cached_input == nullptr)
            cached_input = &proj_dag.addInput(matched.description->column_name, std::move(state_type));

        const auto * result_node = cached_input;
        /// Keep the query-side column name visible to the rest of the plan.
        if (result_node->result_name != query_aggregate.column_name)
            result_node = &proj_dag.addAlias(*result_node, query_aggregate.column_name);

        outputs.push_back(result_node);
    }
}

ActionsDAGPtr analyzeAggregateProjection(
Expand All @@ -246,7 +292,8 @@ ActionsDAGPtr analyzeAggregateProjection(
// static_cast<const void *>(match.node), (match.node ? match.node->result_name : ""), match.monotonicity != std::nullopt);
// }

if (!areAggregatesMatch(info, aggregates, matches, query_index, proj_index))
auto matched_aggregates = matchAggregateFunctions(info, aggregates, matches, query_index, proj_index);
if (!matched_aggregates)
return {};

ActionsDAG::NodeRawConstPtrs query_key_nodes;
Expand Down Expand Up @@ -295,7 +342,7 @@ ActionsDAGPtr analyzeAggregateProjection(

std::stack<Frame> stack;
std::unordered_set<const ActionsDAG::Node *> visited;
std::unordered_map<const ActionsDAG::Node *, std::string> new_inputs;
std::unordered_map<const ActionsDAG::Node *, const ActionsDAG::Node *> new_inputs;

for (const auto * key_node : query_key_nodes)
{
Expand All @@ -317,7 +364,7 @@ ActionsDAGPtr analyzeAggregateProjection(
if (match.node && !match.monotonicity && proj_key_nodes.contains(match.node))
{
visited.insert(frame.node);
new_inputs[frame.node] = match.node->result_name;
new_inputs[frame.node] = match.node;
stack.pop();
continue;
}
Expand Down Expand Up @@ -347,12 +394,7 @@ ActionsDAGPtr analyzeAggregateProjection(
// LOG_TRACE(&Poco::Logger::get("optimizeUseProjections"), "Folding actions by projection");

auto proj_dag = query.dag->foldActionsByProjection(new_inputs, query_key_nodes);

/// Just add all the aggregates to dag inputs.
auto & proj_dag_outputs = proj_dag->getOutputs();
for (const auto & aggregate : aggregates)
proj_dag_outputs.push_back(&proj_dag->addInput(aggregate.column_name, aggregate.function->getResultType()));

appendAggregateFunctions(*proj_dag, aggregates, *matched_aggregates);
return proj_dag;
}

Expand Down
4 changes: 2 additions & 2 deletions tests/queries/0_stateless/01710_projection_with_joins.sql
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ drop table t;

drop table if exists mt;
create table mt (id1 Int8, id2 Int8) Engine=MergeTree order by tuple();
select id1 as alias1 from mt all inner join (select id2 as alias1 from mt) as t using (alias1) order by id1 settings allow_experimental_projection_optimization = 1;
select alias1 from (select id1, id1 as alias1 from mt) as l all inner join (select id2 as alias1 from mt) as t using (alias1) order by l.id1 settings allow_experimental_projection_optimization = 1;
select id1 from mt all inner join (select id2 as id1 from mt) as t using (id1) order by id1 settings allow_experimental_projection_optimization = 1;
select id2 as id1 from mt all inner join (select id1 from mt) as t using (id1) order by id1 settings allow_experimental_projection_optimization = 1;
drop table mt;

drop table if exists j;
create table j (id1 Int8, id2 Int8, projection p (select id1, id2 order by id2)) Engine=MergeTree order by id1 settings index_granularity = 1;
insert into j select number, number from numbers(10);
select id1 as alias1 from j all inner join (select id2 as alias1 from j where id2 in (1, 2, 3)) as t using (alias1) where id2 in (2, 3, 4) order by id1 settings allow_experimental_projection_optimization = 1;
select alias1 from (select id1, id1 as alias1 from j) as l all inner join (select id2, id2 as alias1 from j where id2 in (1, 2, 3)) as t using (alias1) where id2 in (2, 3, 4) order by id1 settings allow_experimental_projection_optimization = 1;
drop table j;
4 changes: 2 additions & 2 deletions tests/queries/0_stateless/01710_projections.sql
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ insert into projection_test with rowNumberInAllBlocks() as id select 1, toDateTi
set allow_experimental_projection_optimization = 1, force_optimize_projection = 1;

select * from projection_test; -- { serverError 584 }
select toStartOfMinute(datetime) dt_m, countIf(first_time = 0) from projection_test join (select 1) x using (1) where domain = '1' group by dt_m order by dt_m; -- { serverError 584 }
select toStartOfMinute(datetime) dt_m, countIf(first_time = 0) from projection_test join (select 1) x on 1 where domain = '1' group by dt_m order by dt_m; -- { serverError 584 }

select toStartOfMinute(datetime) dt_m, countIf(first_time = 0) / count(), avg((kbytes * 8) / duration) from projection_test where domain = '1' group by dt_m order by dt_m;

Expand Down Expand Up @@ -39,7 +39,7 @@ select toStartOfMinute(datetime) dt_m, domain, sum(retry_count) / sum(duration),
select toStartOfHour(toStartOfMinute(datetime)) dt_h, uniqHLL12(x_id), uniqHLL12(y_id) from projection_test group by dt_h order by dt_h;

-- found by fuzzer
SET enable_positional_arguments = 0;
SET enable_positional_arguments = 0, force_optimize_projection = 0;
SELECT 2, -1 FROM projection_test PREWHERE domain_alias = 1. WHERE domain = NULL GROUP BY -9223372036854775808 ORDER BY countIf(first_time = 0) / count(-2147483649) DESC NULLS LAST, 1048576 DESC NULLS LAST;

drop table if exists projection_test;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Fuzzer-found regression: projection analysis over a table with
-- LowCardinality / Nullable columns and a large aggregate projection.
drop table if exists projection_test__fuzz_0;
-- The table below declares types such as LowCardinality(Int64); allow them for this test.
set allow_suspicious_low_cardinality_types=1;

CREATE TABLE projection_test__fuzz_0 (`sum(block_count)` UInt64, `domain_alias` UInt64 ALIAS length(domain), `datetime` DateTime, `domain` LowCardinality(String), `x_id` String, `y_id` String, `block_count` Int64, `retry_count` Int64, `duration` Decimal(76, 13), `kbytes` LowCardinality(Int64), `buffer_time` Int64, `first_time` UInt256, `total_bytes` LowCardinality(Nullable(UInt64)), `valid_bytes` Nullable(UInt64), `completed_bytes` Nullable(UInt64), `fixed_bytes` LowCardinality(Nullable(UInt64)), `force_bytes` Int256, PROJECTION p (SELECT toStartOfMinute(datetime) AS dt_m, countIf(first_time = 0) / count(), avg((kbytes * 8) / duration), count(), sum(block_count) / sum(duration), avg(block_count / duration), sum(buffer_time) / sum(duration), avg(buffer_time / duration), sum(valid_bytes) / sum(total_bytes), sum(completed_bytes) / sum(total_bytes), sum(fixed_bytes) / sum(total_bytes), sum(force_bytes) / sum(total_bytes), sum(valid_bytes) / sum(total_bytes), sum(retry_count) / sum(duration), avg(retry_count / duration), countIf(block_count > 0) / count(), countIf(first_time = 0) / count(), uniqHLL12(x_id), uniqHLL12(y_id) GROUP BY dt_m, domain)) ENGINE = MergeTree PARTITION BY toDate(datetime) ORDER BY (toStartOfTenMinutes(datetime), domain) SETTINGS index_granularity_bytes = 10000000;
INSERT INTO projection_test__fuzz_0 SETTINGS max_threads = 1 WITH rowNumberInAllBlocks() AS id SELECT 1, toDateTime('2020-10-24 00:00:00') + (id / 20), toString(id % 100), * FROM generateRandom('x_id String, y_id String, block_count Int64, retry_count Int64, duration Int64, kbytes Int64, buffer_time Int64, first_time Int64, total_bytes Nullable(UInt64), valid_bytes Nullable(UInt64), completed_bytes Nullable(UInt64), fixed_bytes Nullable(UInt64), force_bytes Nullable(UInt64)', 10, 10, 1) LIMIT 1000 SETTINGS max_threads = 1;
-- ROLLUP + TOTALS query; we only check that it executes without errors (format Null discards the result).
SELECT '-21474836.48', 10000000000., '', count(kbytes), '', 10.0001, toStartOfMinute(datetime) AS dt_m, 10, NULL FROM projection_test__fuzz_0 GROUP BY dt_m WITH ROLLUP WITH TOTALS ORDER BY count(retry_count / duration) ASC NULLS LAST, 100000000000000000000. ASC NULLS FIRST format Null;

drop table projection_test__fuzz_0;
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1
1
4 changes: 4 additions & 0 deletions tests/queries/0_stateless/02516_projections_and_context.sql
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
DROP TABLE IF EXISTS test1__fuzz_37;
CREATE TABLE test1__fuzz_37 (`i` Date) ENGINE = MergeTree ORDER BY i;
insert into test1__fuzz_37 values ('2020-10-10');
-- Without the analyzer, dictHas over non-existing/NULL dictionaries must fail with BAD_ARGUMENTS.
set allow_experimental_analyzer = 0;
SELECT count() FROM test1__fuzz_37 GROUP BY dictHas(NULL, (dictHas(NULL, (('', materialize(NULL)), materialize(NULL))), 'KeyKey')), dictHas('test_dictionary', tuple(materialize('Ke\0'))), tuple(dictHas(NULL, (tuple('Ke\0Ke\0Ke\0Ke\0Ke\0Ke\0\0\0\0Ke\0'), materialize(NULL)))), 'test_dicti\0nary', (('', materialize(NULL)), dictHas(NULL, (dictHas(NULL, tuple(materialize(NULL))), 'KeyKeyKeyKeyKeyKeyKeyKey')), materialize(NULL)); -- { serverError BAD_ARGUMENTS }
SELECT count() FROM test1__fuzz_37 GROUP BY dictHas('non_existing_dictionary', materialize('a')); -- { serverError BAD_ARGUMENTS }
-- With the analyzer enabled, the same queries are expected to succeed.
set allow_experimental_analyzer = 1;
SELECT count() FROM test1__fuzz_37 GROUP BY dictHas(NULL, (dictHas(NULL, (('', materialize(NULL)), materialize(NULL))), 'KeyKey')), dictHas('test_dictionary', tuple(materialize('Ke\0'))), tuple(dictHas(NULL, (tuple('Ke\0Ke\0Ke\0Ke\0Ke\0Ke\0\0\0\0Ke\0'), materialize(NULL)))), 'test_dicti\0nary', (('', materialize(NULL)), dictHas(NULL, (dictHas(NULL, tuple(materialize(NULL))), 'KeyKeyKeyKeyKeyKeyKeyKey')), materialize(NULL));
SELECT count() FROM test1__fuzz_37 GROUP BY dictHas('non_existing_dictionary', materialize('a'));
DROP TABLE test1__fuzz_37;