Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor Visitors Improvement #7411

Merged
merged 1 commit into from
Oct 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions dbms/src/Interpreters/CrossToInnerJoinVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class CheckExpressionVisitorData
, ands_only(true)
{}

void visit(const ASTFunction & node, ASTPtr & ast)
void visit(const ASTFunction & node, const ASTPtr & ast)
{
if (!ands_only)
return;
Expand Down Expand Up @@ -231,8 +231,8 @@ class CheckExpressionVisitorData
}
};

using CheckExpressionMatcher = OneTypeMatcher<CheckExpressionVisitorData, false>;
using CheckExpressionVisitor = InDepthNodeVisitor<CheckExpressionMatcher, true>;
using CheckExpressionMatcher = ConstOneTypeMatcher<CheckExpressionVisitorData, false>;
using CheckExpressionVisitor = ConstInDepthNodeVisitor<CheckExpressionMatcher, true>;


bool getTables(ASTSelectQuery & select, std::vector<JoinedTable> & joined_tables, size_t & num_comma)
Expand Down Expand Up @@ -314,7 +314,7 @@ void CrossToInnerJoinMatcher::visit(ASTSelectQuery & select, ASTPtr &, Data & da
return;

CheckExpressionVisitor::Data visitor_data{joined_tables};
CheckExpressionVisitor(visitor_data).visit(select.refWhere());
CheckExpressionVisitor(visitor_data).visit(select.where());

if (visitor_data.complex())
return;
Expand Down
36 changes: 10 additions & 26 deletions dbms/src/Interpreters/InDepthNodeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ namespace DB

/// Visits AST tree in depth, call functions for nodes according to Matcher type data.
/// You need to define Data, visit() and needChildVisit() in Matcher class.
template <typename Matcher, bool _top_to_bottom, typename T>
class InDepthNodeVisitorTemplate
template <typename Matcher, bool _top_to_bottom, typename T = ASTPtr>
class InDepthNodeVisitor
{
public:
using Data = typename Matcher::Data;

InDepthNodeVisitorTemplate(Data & data_, std::ostream * ostr_ = nullptr)
InDepthNodeVisitor(Data & data_, std::ostream * ostr_ = nullptr)
: data(data_),
visit_depth(0),
ostr(ostr_)
Expand Down Expand Up @@ -49,42 +49,26 @@ class InDepthNodeVisitorTemplate
};

template <typename Matcher, bool top_to_bottom>
using InDepthNodeVisitor = InDepthNodeVisitorTemplate<Matcher, top_to_bottom, ASTPtr>;

template <typename Matcher, bool top_to_bottom>
using ConstInDepthNodeVisitor = InDepthNodeVisitorTemplate<Matcher, top_to_bottom, const ASTPtr>;
using ConstInDepthNodeVisitor = InDepthNodeVisitor<Matcher, top_to_bottom, const ASTPtr>;

/// Simple matcher for one node type without complex traversal logic.
template <typename _Data, bool _visit_children = true>
template <typename Data_, bool visit_children = true, typename T = ASTPtr>
class OneTypeMatcher
{
public:
using Data = _Data;
using Data = Data_;
using TypeToVisit = typename Data::TypeToVisit;

static bool needChildVisit(ASTPtr &, const ASTPtr &) { return _visit_children; }
static bool needChildVisit(const ASTPtr &, const ASTPtr &) { return visit_children; }

static void visit(ASTPtr & ast, Data & data)
static void visit(T & ast, Data & data)
{
if (auto * t = typeid_cast<TypeToVisit *>(ast.get()))
data.visit(*t, ast);
}
};

/// Links two simple matches into resulting one. There's no complex traversal logic: all the children would be visited.
template <typename First, typename Second>
class LinkedMatcher
{
public:
using Data = std::pair<typename First::Data, typename Second::Data>;

static bool needChildVisit(ASTPtr &, const ASTPtr &) { return true; }

static void visit(ASTPtr & ast, Data & data)
{
First::visit(ast, data.first);
Second::visit(ast, data.second);
}
};
template <typename Data, bool visit_children = true>
using ConstOneTypeMatcher = OneTypeMatcher<Data, visit_children, const ASTPtr>;

}
33 changes: 11 additions & 22 deletions dbms/src/Interpreters/JoinToSubqueryTransformVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ namespace
class ExtractAsterisksMatcher
{
public:
using Visitor = InDepthNodeVisitor<ExtractAsterisksMatcher, true>;

struct Data
{
std::unordered_map<String, NamesAndTypesList> table_columns;
Expand Down Expand Up @@ -76,30 +74,16 @@ class ExtractAsterisksMatcher
}
};

static bool needChildVisit(ASTPtr &, const ASTPtr &) { return false; }
static bool needChildVisit(const ASTPtr &, const ASTPtr &) { return false; }

static void visit(ASTPtr & ast, Data & data)
static void visit(const ASTPtr & ast, Data & data)
{
if (auto * t = ast->as<ASTSelectQuery>())
visit(*t, ast, data);
if (auto * t = ast->as<ASTExpressionList>())
visit(*t, ast, data);
}

private:
static void visit(ASTSelectQuery & node, ASTPtr &, Data & data)
{
if (data.table_columns.empty())
return;

Visitor(data).visit(node.refSelect());
if (!data.new_select_expression_list)
return;

node.setExpression(ASTSelectQuery::Expression::SELECT, std::move(data.new_select_expression_list));
}

static void visit(ASTExpressionList & node, ASTPtr &, Data & data)
static void visit(const ASTExpressionList & node, const ASTPtr &, Data & data)
{
bool has_asterisks = false;
data.new_select_expression_list = std::make_shared<ASTExpressionList>();
Expand Down Expand Up @@ -375,7 +359,7 @@ using RewriteMatcher = OneTypeMatcher<RewriteTablesVisitorData>;
using RewriteVisitor = InDepthNodeVisitor<RewriteMatcher, true>;
using SetSubqueryAliasMatcher = OneTypeMatcher<SetSubqueryAliasVisitorData>;
using SetSubqueryAliasVisitor = InDepthNodeVisitor<SetSubqueryAliasMatcher, true>;
using ExtractAsterisksVisitor = ExtractAsterisksMatcher::Visitor;
using ExtractAsterisksVisitor = ConstInDepthNodeVisitor<ExtractAsterisksMatcher, true>;
using ColumnAliasesVisitor = ConstInDepthNodeVisitor<ColumnAliasesMatcher, true>;
using AppendSemanticMatcher = OneTypeMatcher<AppendSemanticVisitorData>;
using AppendSemanticVisitor = InDepthNodeVisitor<AppendSemanticMatcher, true>;
Expand All @@ -389,7 +373,7 @@ void JoinToSubqueryTransformMatcher::visit(ASTPtr & ast, Data & data)
visit(*t, ast, data);
}

void JoinToSubqueryTransformMatcher::visit(ASTSelectQuery & select, ASTPtr & ast, Data & data)
void JoinToSubqueryTransformMatcher::visit(ASTSelectQuery & select, ASTPtr &, Data & data)
{
using RevertedAliases = AsteriskSemantic::RevertedAliases;

Expand All @@ -398,7 +382,12 @@ void JoinToSubqueryTransformMatcher::visit(ASTSelectQuery & select, ASTPtr & ast
return;

ExtractAsterisksVisitor::Data asterisks_data(data.context, table_expressions);
ExtractAsterisksVisitor(asterisks_data).visit(ast);
if (!asterisks_data.table_columns.empty())
{
ExtractAsterisksVisitor(asterisks_data).visit(select.select());
if (asterisks_data.new_select_expression_list)
select.setExpression(ASTSelectQuery::Expression::SELECT, std::move(asterisks_data.new_select_expression_list));
}

ColumnAliasesVisitor::Data aliases_data(getDatabaseAndTables(select, ""));
if (select.select())
Expand Down