Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax ambiguous column check for multiple JOIN ON section #8385

Merged
merged 1 commit into from Dec 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 0 additions & 8 deletions dbms/src/Interpreters/AnalyzedJoin.cpp
Expand Up @@ -98,14 +98,6 @@ NameSet AnalyzedJoin::getQualifiedColumnsSet() const
return out;
}

/// Builds a set from the mapped (`second`) value of every entry in `original_names`.
/// The keys of `original_names` are ignored; duplicates among the values collapse in the set.
NameSet AnalyzedJoin::getOriginalColumnsSet() const
{
    NameSet original_columns;

    for (const auto & entry : original_names)
        original_columns.insert(entry.second);

    return original_columns;
}

NamesWithAliases AnalyzedJoin::getNamesWithAliases(const NameSet & required_columns) const
{
NamesWithAliases out;
Expand Down
1 change: 0 additions & 1 deletion dbms/src/Interpreters/AnalyzedJoin.h
Expand Up @@ -96,7 +96,6 @@ class AnalyzedJoin
bool hasOn() const { return table_join.on_expression != nullptr; }

NameSet getQualifiedColumnsSet() const;
NameSet getOriginalColumnsSet() const;
NamesWithAliases getNamesWithAliases(const NameSet & required_columns) const;
NamesWithAliases getRequiredColumns(const Block & sample, const Names & action_columns) const;

Expand Down
17 changes: 14 additions & 3 deletions dbms/src/Interpreters/CollectJoinOnKeysVisitor.cpp
Expand Up @@ -170,11 +170,22 @@ size_t CollectJoinOnKeysMatcher::getTableForIdentifiers(std::vector<const ASTIde
if (!membership)
{
const String & name = identifier->name;
bool in_left_table = data.source_columns.count(name);
bool in_right_table = data.joined_columns.count(name);
bool in_left_table = data.left_table.hasColumn(name);
bool in_right_table = data.right_table.hasColumn(name);

if (in_left_table && in_right_table)
throw Exception("Column '" + name + "' is ambiguous", ErrorCodes::AMBIGUOUS_COLUMN_NAME);
{
/// Relax ambiguous check for multiple JOINs
if (auto original_name = IdentifierSemantic::uncover(*identifier))
{
auto match = IdentifierSemantic::canReferColumnToTable(*original_name, data.right_table.table);
if (match == IdentifierSemantic::ColumnMatch::NoMatch)
in_right_table = false;
in_left_table = !in_right_table;
}
else
throw Exception("Column '" + name + "' is ambiguous", ErrorCodes::AMBIGUOUS_COLUMN_NAME);
}

if (in_left_table)
membership = 1;
Expand Down
5 changes: 3 additions & 2 deletions dbms/src/Interpreters/CollectJoinOnKeysVisitor.h
Expand Up @@ -3,6 +3,7 @@
#include <Core/Names.h>
#include <Parsers/ASTFunction.h>
#include <Interpreters/InDepthNodeVisitor.h>
#include <Interpreters/DatabaseAndTableWithAlias.h>
#include <Interpreters/Aliases.h>


Expand All @@ -25,8 +26,8 @@ class CollectJoinOnKeysMatcher
struct Data
{
AnalyzedJoin & analyzed_join;
const NameSet & source_columns;
const NameSet & joined_columns;
const TableWithColumnNames & left_table;
const TableWithColumnNames & right_table;
const Aliases & aliases;
const bool is_asof{false};
ASTPtr asof_left_key{};
Expand Down
14 changes: 14 additions & 0 deletions dbms/src/Interpreters/DatabaseAndTableWithAlias.h
Expand Up @@ -53,6 +53,20 @@ struct TableWithColumnNames
for (auto & column : addition)
hidden_columns.push_back(column.name);
}

bool hasColumn(const String & name) const
{
if (columns_set.empty())
{
columns_set.insert(columns.begin(), columns.end());
columns_set.insert(hidden_columns.begin(), hidden_columns.end());
}

return columns_set.count(name);
}

private:
mutable NameSet columns_set;
};

std::vector<DatabaseAndTableWithAlias> getDatabaseAndTables(const ASTSelectQuery & select_query, const String & current_database);
Expand Down
16 changes: 16 additions & 0 deletions dbms/src/Interpreters/IdentifierSemantic.cpp
Expand Up @@ -92,6 +92,22 @@ std::optional<String> IdentifierSemantic::getTableName(const ASTPtr & ast)
return {};
}

/// Rebuilds the identifier that is hidden behind a short name.
/// Returns an empty optional unless the identifier's semantic is flagged `covered`;
/// otherwise returns a new ASTIdentifier constructed from a copy of `identifier.name_parts`.
std::optional<ASTIdentifier> IdentifierSemantic::uncover(const ASTIdentifier & identifier)
{
    const bool is_covered = identifier.semantic->covered;
    if (!is_covered)
        return {};

    /// Copy the parts first: the source identifier is const and stays untouched.
    std::vector<String> original_parts = identifier.name_parts;
    return ASTIdentifier(std::move(original_parts));
}

/// Hides the identifier's name behind `alias`: sets `alias` as its short name and
/// raises the `covered` flag on its semantic, which is the flag `uncover()` checks.
void IdentifierSemantic::coverName(ASTIdentifier & identifier, const String & alias)
{
identifier.setShortName(alias);
/// Mark after renaming so the identifier is known to carry a substituted short name.
identifier.semantic->covered = true;
}

bool IdentifierSemantic::canBeAlias(const ASTIdentifier & identifier)
{
return identifier.semantic->can_be_alias;
Expand Down
3 changes: 3 additions & 0 deletions dbms/src/Interpreters/IdentifierSemantic.h
Expand Up @@ -12,6 +12,7 @@ struct IdentifierSemanticImpl
{
bool special = false; /// for now it's 'not a column': tables, subselects and some special stuff like FORMAT
bool can_be_alias = true; /// if it's a cropped name it could not be an alias
bool covered = false; /// real (compound) name is hidden by an alias (short name)
std::optional<size_t> membership; /// table position in join
};

Expand Down Expand Up @@ -43,6 +44,8 @@ struct IdentifierSemantic
static void setColumnLongName(ASTIdentifier & identifier, const DatabaseAndTableWithAlias & db_and_table);
static bool canBeAlias(const ASTIdentifier & identifier);
static void setMembership(ASTIdentifier &, size_t table_no);
static void coverName(ASTIdentifier &, const String & alias);
static std::optional<ASTIdentifier> uncover(const ASTIdentifier & identifier);
static std::optional<size_t> getMembership(const ASTIdentifier & identifier);
static bool chooseTable(const ASTIdentifier &, const std::vector<DatabaseAndTableWithAlias> & tables, size_t & best_table_pos,
bool ambiguous = false);
Expand Down
6 changes: 3 additions & 3 deletions dbms/src/Interpreters/JoinToSubqueryTransformVisitor.cpp
Expand Up @@ -159,7 +159,7 @@ struct ColumnAliasesMatcher
aliases[alias] = long_name;
rev_aliases[long_name].push_back(alias);

identifier->setShortName(alias);
IdentifierSemantic::coverName(*identifier, alias);
if (is_public)
{
identifier->setAlias(long_name);
Expand All @@ -177,7 +177,7 @@ struct ColumnAliasesMatcher
if (is_public && allowed_long_names.count(long_name))
; /// leave original name unchanged for correct output
else
identifier->setShortName(it->second[0]);
IdentifierSemantic::coverName(*identifier, it->second[0]);
}
}
}
Expand Down Expand Up @@ -229,7 +229,7 @@ struct ColumnAliasesMatcher

if (!last_table)
{
node.setShortName(alias);
IdentifierSemantic::coverName(node, alias);
node.setAlias("");
}
}
Expand Down
11 changes: 6 additions & 5 deletions dbms/src/Interpreters/SyntaxAnalyzer.cpp
Expand Up @@ -532,8 +532,8 @@ void setJoinStrictness(ASTSelectQuery & select_query, JoinStrictness join_defaul
}

/// Find the columns that are obtained by JOIN.
void collectJoinedColumns(AnalyzedJoin & analyzed_join, const ASTSelectQuery & select_query, const NameSet & source_columns,
const Aliases & aliases)
void collectJoinedColumns(AnalyzedJoin & analyzed_join, const ASTSelectQuery & select_query,
const std::vector<TableWithColumnNames> & tables, const Aliases & aliases)
{
const ASTTablesInSelectQueryElement * node = select_query.join();
if (!node)
Expand All @@ -551,7 +551,7 @@ void collectJoinedColumns(AnalyzedJoin & analyzed_join, const ASTSelectQuery & s
{
bool is_asof = (table_join.strictness == ASTTableJoin::Strictness::Asof);

CollectJoinOnKeysVisitor::Data data{analyzed_join, source_columns, analyzed_join.getOriginalColumnsSet(), aliases, is_asof};
CollectJoinOnKeysVisitor::Data data{analyzed_join, tables[0], tables[1], aliases, is_asof};
CollectJoinOnKeysVisitor(data).visit(table_join.on_expression);
if (!data.has_some)
throw Exception("Cannot get JOIN keys from JOIN ON section: " + queryToString(table_join.on_expression),
Expand Down Expand Up @@ -820,6 +820,7 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(
if (storage)
collectSourceColumns(storage->getColumns(), result.source_columns, (select_query != nullptr));
NameSet source_columns_set = removeDuplicateColumns(result.source_columns);
std::vector<TableWithColumnNames> tables_with_columns;

if (select_query)
{
Expand All @@ -837,7 +838,7 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(
}

std::vector<const ASTTableExpression *> table_expressions = getTableExpressions(*select_query);
auto tables_with_columns = getTablesWithColumns(table_expressions, context);
tables_with_columns = getTablesWithColumns(table_expressions, context);

if (tables_with_columns.empty())
{
Expand Down Expand Up @@ -935,7 +936,7 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyze(

setJoinStrictness(*select_query, settings.join_default_strictness, settings.any_join_distinct_right_table_keys,
result.analyzed_join->table_join);
collectJoinedColumns(*result.analyzed_join, *select_query, source_columns_set, result.aliases);
collectJoinedColumns(*result.analyzed_join, *select_query, tables_with_columns, result.aliases);
}

result.aggregates = getAggregates(query);
Expand Down
@@ -0,0 +1,29 @@
-- Regression test for column resolution in the ON sections of chained JOINs:
-- three tables that all expose a column named x are joined with two LEFT JOINs,
-- and every ON section refers to x through a table qualifier.
DROP TABLE IF EXISTS a;
DROP TABLE IF EXISTS b;
DROP TABLE IF EXISTS c;

-- Identical single-column schemas, so the name x exists in every joined table.
-- No rows are inserted; the queries below only have to be analyzed and run without error.
CREATE TABLE a (x UInt64) ENGINE = Memory;
CREATE TABLE b (x UInt64) ENGINE = Memory;
CREATE TABLE c (x UInt64) ENGINE = Memory;

-- Predicate-expression optimization is switched off for the queries below.
SET enable_optimize_predicate_expression = 0;

-- Second JOIN is keyed on the first table (a.x = c.x); a.x is selected.
SELECT a.x AS x FROM a
LEFT JOIN b ON a.x = b.x
LEFT JOIN c ON a.x = c.x;

-- Second JOIN is keyed on the middle table (b.x = c.x); a.x is selected.
SELECT a.x AS x FROM a
LEFT JOIN b ON a.x = b.x
LEFT JOIN c ON b.x = c.x;

-- Same joins; the selected column comes from the middle table.
SELECT b.x AS x FROM a
LEFT JOIN b ON a.x = b.x
LEFT JOIN c ON b.x = c.x;

-- Same joins; the selected column comes from the last table.
SELECT c.x AS x FROM a
LEFT JOIN b ON a.x = b.x
LEFT JOIN c ON b.x = c.x;

-- Clean up the tables created by this test.
DROP TABLE a;
DROP TABLE b;
DROP TABLE c;